[tf.data] This CL changes how the in-memory cache resource is managed, making it possible for the cache dataset to support both a) sharing of the cache across iterators and b) serialization. As a consequence, this CL enables sharing of the cache across iterators for tf.distribute and tf.data service use cases (which require serialization support).

PiperOrigin-RevId: 307469217
Change-Id: Ia4f4384752609b83cb10078ebeb20dfc6c8a2d8f
This commit is contained in:
Jiri Simsa 2020-04-20 13:51:58 -07:00 committed by TensorFlower Gardener
parent b592f87bd8
commit 7ebbab819e
12 changed files with 211 additions and 139 deletions

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "DummyMemoryCache"
visibility: HIDDEN
}

View File

@ -27,8 +27,6 @@ namespace tensorflow {
namespace grappler { namespace grappler {
namespace { namespace {
constexpr char kCacheDataset[] = "CacheDataset";
constexpr char kCacheDatasetV2[] = "CacheDatasetV2";
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration"; constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
constexpr char kShuffleDataset[] = "ShuffleDataset"; constexpr char kShuffleDataset[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2"; constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
@ -55,10 +53,6 @@ Status MakeStateless::OptimizeAndCollectStats(Cluster* cluster,
node.add_input(zero_node->name()); node.add_input(zero_node->name());
// set `reshuffle_each_iteration` attr // set `reshuffle_each_iteration` attr
(*node.mutable_attr())[kReshuffleEachIteration].set_b(true); (*node.mutable_attr())[kReshuffleEachIteration].set_b(true);
} else if (node.op() == kCacheDatasetV2) {
*node.mutable_op() = kCacheDataset;
// remove `cache` input
node.mutable_input()->RemoveLast();
} }
} }

View File

@ -28,29 +28,6 @@ namespace tensorflow {
namespace grappler { namespace grappler {
namespace { namespace {
TEST(MakeStateless, Cache) {
using test::function::NDef;
GrapplerItem item;
item.graph = test::function::GDef(
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_INT64}}),
NDef("handle", "Const", {}, {{"value", 1}, {"dtype", DT_RESOURCE}}),
graph_tests_utils::MakeCacheV2Node("cache", "range", "filename",
"handle")},
{});
MakeStateless optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("cache", output));
int index = graph_utils::FindGraphNodeWithName("cache", output);
EXPECT_EQ(output.node(index).op(), "CacheDataset");
EXPECT_EQ(output.node(index).input_size(), 2);
}
TEST(MakeStateless, Shuffle) { TEST(MakeStateless, Shuffle) {
using test::function::NDef; using test::function::NDef;
GrapplerItem item; GrapplerItem item;

View File

@ -48,7 +48,6 @@ constexpr char kShardId[] = "shard_id";
constexpr char kCreatedAt[] = "Created at"; constexpr char kCreatedAt[] = "Created at";
constexpr char kMemoryDatasetPrefix[] = "Memory"; constexpr char kMemoryDatasetPrefix[] = "Memory";
constexpr char kMemoryCache[] = "MemoryCache"; constexpr char kMemoryCache[] = "MemoryCache";
constexpr char kTFData[] = "tf_data";
constexpr char kCacheClaimed[] = "cache_claimed"; constexpr char kCacheClaimed[] = "cache_claimed";
constexpr char kCacheSize[] = "cache_size"; constexpr char kCacheSize[] = "cache_size";
constexpr char kCache[] = "cache"; constexpr char kCache[] = "cache";
@ -58,10 +57,10 @@ constexpr char kIndex[] = "index";
constexpr char kImpl[] = "Impl"; constexpr char kImpl[] = "Impl";
constexpr char kCacheDataset[] = "CacheDataset"; constexpr char kCacheDataset[] = "CacheDataset";
class CacheDatasetOp::FileDataset : public DatasetBase { class CacheDatasetOp::FileDatasetBase : public DatasetBase {
public: public:
explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input, FileDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
string filename, Env* env) string filename, Env* env)
: DatasetBase(DatasetContext(ctx)), : DatasetBase(DatasetContext(ctx)),
input_(input), input_(input),
filename_(std::move(filename)), filename_(std::move(filename)),
@ -76,7 +75,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
DCHECK_EQ(item_index_padding_size_, 7); DCHECK_EQ(item_index_padding_size_, 7);
} }
~FileDataset() override { input_->Unref(); } ~FileDatasetBase() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal( std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override { const string& prefix) const override {
@ -107,17 +106,6 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
} }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
Node* filename = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
return Status::OK();
}
const DatasetBase* const input_; const DatasetBase* const input_;
const tstring filename_; const tstring filename_;
@ -131,10 +119,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
tensor_index); tensor_index);
} }
class FileIterator : public DatasetIterator<FileDataset> { class FileIterator : public DatasetIterator<FileDatasetBase> {
public: public:
explicit FileIterator(const Params& params) explicit FileIterator(const Params& params)
: DatasetIterator<FileDataset>(params) { : DatasetIterator<FileDatasetBase>(params) {
if (params.dataset->env_ if (params.dataset->env_
->FileExists(MetaFilename(params.dataset->filename_)) ->FileExists(MetaFilename(params.dataset->filename_))
.ok()) { .ok()) {
@ -199,7 +187,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
private: private:
// FileWriterIterator passes through and caches items from the input // FileWriterIterator passes through and caches items from the input
// FileDataset. // FileDatasetBase.
// //
// This iterator is used when the cache directory is not found on disk. It // This iterator is used when the cache directory is not found on disk. It
// creates the cache directory, and passes on the underlying iterator's // creates the cache directory, and passes on the underlying iterator's
@ -214,10 +202,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
// partial cache gets flushed to disk in files with prefix // partial cache gets flushed to disk in files with prefix
// <filename>_<shard_id> where shard_id is unique for each checkpoint. // <filename>_<shard_id> where shard_id is unique for each checkpoint.
// When all elements have been produced, these shards get coalesced. // When all elements have been produced, these shards get coalesced.
class FileWriterIterator : public DatasetIterator<FileDataset> { class FileWriterIterator : public DatasetIterator<FileDatasetBase> {
public: public:
explicit FileWriterIterator(const Params& params) explicit FileWriterIterator(const Params& params)
: DatasetIterator<FileDataset>(params), : DatasetIterator<FileDatasetBase>(params),
cur_index_(0), cur_index_(0),
shard_id_(0), shard_id_(0),
filename_( filename_(
@ -483,10 +471,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
bool iteration_completed_ TF_GUARDED_BY(mu_); bool iteration_completed_ TF_GUARDED_BY(mu_);
}; // FileWriterIterator }; // FileWriterIterator
class FileReaderIterator : public DatasetIterator<FileDataset> { class FileReaderIterator : public DatasetIterator<FileDatasetBase> {
public: public:
explicit FileReaderIterator(const Params& params) explicit FileReaderIterator(const Params& params)
: DatasetIterator<FileDataset>(params), : DatasetIterator<FileDatasetBase>(params),
cur_index_(0), cur_index_(0),
reader_(dataset()->env_, dataset()->filename_), reader_(dataset()->env_, dataset()->filename_),
iterator_restored_(false) {} iterator_restored_(false) {}
@ -606,14 +594,31 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
static constexpr size_t kMaxItems = 10000000; // 10 million static constexpr size_t kMaxItems = 10000000; // 10 million
const size_t item_index_padding_size_; const size_t item_index_padding_size_;
const string tensor_format_string_; const string tensor_format_string_;
}; // FileDataset }; // FileDatasetBase
class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDataset { class CacheDatasetOp::FileDataset : public CacheDatasetOp::FileDatasetBase {
public:
using FileDatasetBase::FileDatasetBase;
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
Node* filename = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
return Status::OK();
}
};
class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDatasetBase {
public: public:
explicit FileDatasetV2(OpKernelContext* ctx, const DatasetBase* input, explicit FileDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
string filename, Env* env, string filename, Env* env,
const Tensor& resource_handle) const Tensor& resource_handle)
: FileDataset(ctx, input, filename, env), : FileDatasetBase(ctx, input, filename, env),
resource_handle_(resource_handle) {} resource_handle_(resource_handle) {}
protected: protected:
@ -686,20 +691,15 @@ Status RestoreCache(IteratorContext* ctx, IteratorStateReader* reader, T* cache,
} // namespace } // namespace
class CacheDatasetOp::MemoryDataset : public DatasetBase { class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
public: public:
explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input, explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
MemoryCache* cache) MemoryCache* cache)
: DatasetBase(DatasetContext(ctx)), input_(input), cache_(cache) { : DatasetBase(DatasetContext(ctx)), input_(input), cache_(cache) {
input_->Ref(); input_->Ref();
} }
~MemoryDataset() override { ~MemoryDatasetBase() override { input_->Unref(); }
input_->Unref();
if (cache_) {
cache_->Unref();
}
}
std::unique_ptr<IteratorBase> MakeIteratorInternal( std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override { const string& prefix) const override {
@ -732,44 +732,13 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
} }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, class MemoryIterator : public DatasetIterator<MemoryDatasetBase> {
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* filename_node = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
TF_RETURN_IF_ERROR(
b->AddDataset(this, {input_node, filename_node}, output));
return Status::OK();
}
class MemoryIterator : public DatasetIterator<MemoryDataset> {
public: public:
explicit MemoryIterator(const Params& params, MemoryCache* cache) explicit MemoryIterator(const Params& params, MemoryCache* cache)
: DatasetIterator<MemoryDataset>(params), cache_(cache) {} : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}
~MemoryIterator() override {
if (dataset()->cache_ == nullptr) {
cache_->Unref();
}
}
Status Initialize(IteratorContext* ctx) override { Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_); mutex_lock l(mu_);
if (cache_ == nullptr) {
// Use the resource manager in the iterator context to get / create
// a cache.
ResourceMgr* mgr = ctx->resource_mgr();
const string name = strings::StrCat(
prefix(), name_utils::kDelimiter, dataset()->node_name(),
name_utils::kDelimiter, kMemoryCache);
TF_RETURN_IF_ERROR(mgr->LookupOrCreate<MemoryCache>(
kTFData, name, &cache_, [](MemoryCache** cache) {
*cache = new MemoryCache();
return Status::OK();
}));
}
InitializeIterator(); InitializeIterator();
return iterator_->Initialize(ctx); return iterator_->Initialize(ctx);
} }
@ -817,10 +786,10 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
} }
private: private:
class MemoryWriterIterator : public DatasetIterator<MemoryDataset> { class MemoryWriterIterator : public DatasetIterator<MemoryDatasetBase> {
public: public:
explicit MemoryWriterIterator(const Params& params, MemoryCache* cache) explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
: DatasetIterator<MemoryDataset>(params), cache_(cache) {} : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}
~MemoryWriterIterator() override { ~MemoryWriterIterator() override {
mutex_lock l(mu_); mutex_lock l(mu_);
@ -900,12 +869,12 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
std::vector<std::vector<Tensor>> temp_cache_ TF_GUARDED_BY(mu_); std::vector<std::vector<Tensor>> temp_cache_ TF_GUARDED_BY(mu_);
}; // MemoryWriterIterator }; // MemoryWriterIterator
class MemoryReaderIterator : public DatasetIterator<MemoryDataset> { class MemoryReaderIterator : public DatasetIterator<MemoryDatasetBase> {
public: public:
explicit MemoryReaderIterator(const Params& params, MemoryCache* cache) explicit MemoryReaderIterator(const Params& params, MemoryCache* cache)
: DatasetIterator<MemoryDataset>(params), cache_(cache), index_(0) { : DatasetIterator<MemoryDatasetBase>(params),
CHECK(cache); cache_(cache),
} index_(0) {}
Status Initialize(IteratorContext* ctx) override { Status Initialize(IteratorContext* ctx) override {
// The memory allocated for the cache is owned by the parent // The memory allocated for the cache is owned by the parent
@ -988,19 +957,80 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
}; // MemoryIterator }; // MemoryIterator
const DatasetBase* const input_; const DatasetBase* const input_;
MemoryCache* cache_ = nullptr; MemoryCache* const cache_;
}; // MemoryDataset }; // MemoryDatasetBase
class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset { // This version of memory dataset has an exclusive ownership of the memory cache
// resource. It supports sharing of the cache across different iterations of the
// `repeat` transformation but not across different iterators.
class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
public: public:
explicit MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input, MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
MemoryCache* cache, MemoryCache* cache, const ResourceHandle& resource_handle)
std::unique_ptr<OwnedResourceHandle> handle) : MemoryDatasetBase(ctx, input, cache),
: MemoryDataset(ctx, input, cache), handle_(std::move(handle)) {} resource_handle_(resource_handle) {
cleanup_ = [this, mgr = ctx->resource_manager()]() {
DCHECK(cache_->RefCountIsOne());
Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
resource_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
}
};
}
Status CheckExternalState() const override { ~MemoryDataset() override {
return errors::FailedPrecondition(DebugString(), cache_->Unref();
" depends on memory cache resource."); cleanup_();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* filename_node = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
TF_RETURN_IF_ERROR(
b->AddDataset(this, {input_node, filename_node}, output));
return Status::OK();
}
private:
std::function<void()> cleanup_;
const ResourceHandle resource_handle_;
};
// This version of memory dataset has a shared ownership of the memory cache
// resource. It supports sharing of the cache across different iterations of
// the `repeat` transformation and also across different iterators.
class CacheDatasetOp::MemoryDatasetV2
: public CacheDatasetOp::MemoryDatasetBase {
public:
MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
MemoryCache* cache, const ResourceHandle& resource_handle)
: MemoryDatasetBase(ctx, input, cache),
resource_handle_(std::move(resource_handle)) {
cleanup_ = [this, mgr = ctx->resource_manager()]() {
if (cache_->RefCountIsOne()) {
Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
resource_handle_.name());
if (!s.ok()) {
if (errors::IsNotFound(s)) {
// This is a benign race resulting from concurrent deletion.
VLOG(1) << "Failed to delete cache resource: " << s.ToString();
} else {
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
}
}
}
};
}
~MemoryDatasetV2() override {
cache_->Unref();
cleanup_();
} }
protected: protected:
@ -1013,7 +1043,7 @@ class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset {
TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node)); TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
Node* resource_handle_node = nullptr; Node* resource_handle_node = nullptr;
Tensor handle(DT_RESOURCE, TensorShape({})); Tensor handle(DT_RESOURCE, TensorShape({}));
handle.scalar<ResourceHandle>()() = handle_->handle(); handle.scalar<ResourceHandle>()() = resource_handle_;
TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node)); TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
TF_RETURN_IF_ERROR(b->AddDataset( TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_node, filename_node, resource_handle_node}, output)); this, {input_node, filename_node, resource_handle_node}, output));
@ -1021,7 +1051,8 @@ class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset {
} }
private: private:
std::unique_ptr<OwnedResourceHandle> handle_; std::function<void()> cleanup_;
const ResourceHandle resource_handle_;
}; };
CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx) CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
@ -1033,22 +1064,39 @@ void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
// Parse out the filenames tensor. // Parse out the filenames tensor.
tstring filename; tstring filename;
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, kFileName, &filename)); OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, kFileName, &filename));
if (filename.empty()) { if (filename.empty()) {
static std::atomic<int64> resource_id_counter(0);
const string& container = ctx->resource_manager()->default_container();
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
resource_id_counter.fetch_add(1));
if (op_version_ == 2) { if (op_version_ == 2) {
MemoryCache* cache = nullptr; MemoryCache* cache = nullptr;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &cache)); auto handle = HandleFromInput(ctx, 2);
Status s = ctx->resource_manager()->Lookup<MemoryCache>(
// Create a fresh handle for the resource because the input handle can handle.container(), handle.name(), &cache);
// become invalid after this op executes. if (errors::IsNotFound(s)) {
std::unique_ptr<OwnedResourceHandle> handle; OP_REQUIRES_OK(ctx,
OP_REQUIRES_OK( ctx->resource_manager()->LookupOrCreate<MemoryCache>(
ctx, OwnedResourceHandle::Create(ctx, cache, kMemoryCache, &handle)); container, name, &cache, [](MemoryCache** cache) {
*cache = new MemoryCache();
return Status::OK();
}));
handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
} else {
OP_REQUIRES_OK(ctx, s);
}
// Ownership of cache is transferred onto `MemoryDatasetV2`. // Ownership of cache is transferred onto `MemoryDatasetV2`.
*output = new MemoryDatasetV2(ctx, input, cache, std::move(handle)); *output = new MemoryDatasetV2(ctx, input, cache, std::move(handle));
} else { } else {
*output = new MemoryDataset(ctx, input, /*cache=*/nullptr); MemoryCache* cache;
OP_REQUIRES_OK(ctx, ctx->resource_manager()->LookupOrCreate<MemoryCache>(
container, name, &cache, [](MemoryCache** cache) {
*cache = new MemoryCache();
return Status::OK();
}));
auto handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
// Ownership of cache is transferred onto `MemoryDataset`.
*output = new MemoryDataset(ctx, input, cache, handle);
} }
} else { } else {
if (op_version_ == 2) { if (op_version_ == 2) {

View File

@ -22,8 +22,8 @@ namespace data {
class CacheDatasetOp : public UnaryDatasetOpKernel { class CacheDatasetOp : public UnaryDatasetOpKernel {
public: public:
class FileDataset; class FileDatasetBase;
class MemoryDataset; class MemoryDatasetBase;
static constexpr const char* const kDatasetType = "Cache"; static constexpr const char* const kDatasetType = "Cache";
static constexpr const char* const kInputDataset = "input_dataset"; static constexpr const char* const kInputDataset = "input_dataset";
@ -38,10 +38,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
DatasetBase** output) override; DatasetBase** output) override;
private: private:
class FileDataset;
class FileDatasetV2; class FileDatasetV2;
class MemoryDataset;
class MemoryDatasetV2; class MemoryDatasetV2;
int op_version_; const int op_version_;
}; };
} // namespace data } // namespace data

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/lib/random/random_distributions.h"
@ -26,7 +27,7 @@ namespace tensorflow {
namespace data { namespace data {
namespace { namespace {
const char kMemoryCache[] = "MemoryCache"; constexpr char kMemoryCache[] = "MemoryCache";
} // namespace } // namespace
@ -82,10 +83,11 @@ Status AnonymousMemoryCacheHandleOp::CreateResource(
void DeleteMemoryCacheOp::Compute(OpKernelContext* ctx) { void DeleteMemoryCacheOp::Compute(OpKernelContext* ctx) {
const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0); const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
// The resource is guaranteed to exist because the variant tensor wrapping the // The resource might have been already deleted by the dataset.
// deleter is provided as an unused input to this op, which guarantees that it Status s = ctx->resource_manager()->Delete(handle);
// has not run yet. if (!errors::IsNotFound(s)) {
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Delete(handle)); OP_REQUIRES_OK(ctx, s);
}
} }
namespace { namespace {
@ -96,6 +98,9 @@ REGISTER_KERNEL_BUILDER(Name("AnonymousMemoryCache").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("DeleteMemoryCache").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("DeleteMemoryCache").Device(DEVICE_CPU),
DeleteMemoryCacheOp); DeleteMemoryCacheOp);
REGISTER_KERNEL_BUILDER(Name("DummyMemoryCache").Device(DEVICE_CPU),
DummyResourceOp<MemoryCache>);
} // namespace } // namespace
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow

View File

@ -286,6 +286,28 @@ Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
std::function<void(std::function<void()>)> RunnerWithMaxParallelism( std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
std::function<void(std::function<void()>)> runner, int max_parallelism); std::function<void(std::function<void()>)> runner, int max_parallelism);
// Op for creating a typed dummy resource.
//
// This op is used to provide a resource "placeholder" for ops such as
// `CacheDatasetV2` or `ShuffleDatasetV2` that expect a resource input.
// Originally, the lifetime of the resources passed into these ops was managed
// externally. After the implementation changed to manage the lifetime of the
// resources (including creation) by the ops themselves, the resource input is
// only needed to pass a resource handle through graph rewrites. When they are
// invoked from user code, the implementation passes in a dummy resource.
// Kernel that produces a scalar `ResourceHandle` output naming a dummy
// resource of type `ResourceType`. The handle uses an empty container and a
// fixed "dummy_resource" name; it exists only so graphs that require a
// resource input (e.g. `CacheDatasetV2`) have something to wire through.
template <typename ResourceType>
class DummyResourceOp : public OpKernel {
 public:
  explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // Allocate a scalar output tensor and fill it with the placeholder handle.
    Tensor* tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor));
    tensor->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
        ctx, /*container=*/"", /*name=*/"dummy_resource");
  }
};
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow

View File

@ -504,6 +504,13 @@ REGISTER_OP("DeleteMemoryCache")
.Input("deleter: variant") .Input("deleter: variant")
.SetShapeFn(shape_inference::NoOutputs); .SetShapeFn(shape_inference::NoOutputs);
// Registers the op that emits a placeholder memory-cache resource handle
// (see `DummyResourceOp`). Output is a single scalar resource handle.
REGISTER_OP("DummyMemoryCache")
    .Output("handle: resource")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return Status::OK();
    });
REGISTER_OP("CacheDataset") REGISTER_OP("CacheDataset")
.Input("input_dataset: variant") .Input("input_dataset: variant")
.Input("filename: string") .Input("filename: string")

View File

@ -244,8 +244,6 @@ class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset_ops.Dataset.from_tensor_slices(components).repeat(0)) dataset_ops.Dataset.from_tensor_slices(components).repeat(0))
cache_dataset = repeat_dataset.cache() cache_dataset = repeat_dataset.cache()
# Create initialization ops for iterators without and with
# caching, respectively.
self.assertDatasetProduces(cache_dataset, expected_output=[]) self.assertDatasetProduces(cache_dataset, expected_output=[])
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())

View File

@ -3526,6 +3526,8 @@ class RangeDataset(DatasetSource):
return self._structure return self._structure
# This can be deleted after the forward compatibility window for switching
# to using dummy resource expires on 5/20.
class _MemoryCacheDeleter(object): class _MemoryCacheDeleter(object):
"""An object which cleans up an anonymous memory cache resource. """An object which cleans up an anonymous memory cache resource.
@ -3552,15 +3554,20 @@ class _MemoryCacheDeleter(object):
handle=self._handle, deleter=self._deleter) handle=self._handle, deleter=self._deleter)
# This can be deleted after the forward compatibility window for switching
# to using dummy resource expires on 5/20.
class _MemoryCache(object): class _MemoryCache(object):
"""Represents a memory cache resource.""" """Represents a memory cache resource."""
def __init__(self): def __init__(self):
super(_MemoryCache, self).__init__() super(_MemoryCache, self).__init__()
self._device = context.context().device_name if compat.forward_compatible(2020, 5, 20):
self._handle, self._deleter = (gen_dataset_ops.anonymous_memory_cache()) self._handle = gen_dataset_ops.dummy_memory_cache()
self._resource_deleter = _MemoryCacheDeleter( else:
handle=self._handle, device=self._device, deleter=self._deleter) self._device = context.context().device_name
self._handle, self._deleter = gen_dataset_ops.anonymous_memory_cache()
self._resource_deleter = _MemoryCacheDeleter(
handle=self._handle, device=self._device, deleter=self._deleter)
@property @property
def handle(self): def handle(self):

View File

@ -1176,6 +1176,10 @@ tf_module {
name: "DrawBoundingBoxesV2" name: "DrawBoundingBoxesV2"
argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "DummyMemoryCache"
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method { member_method {
name: "DynamicPartition" name: "DynamicPartition"
argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -1176,6 +1176,10 @@ tf_module {
name: "DrawBoundingBoxesV2" name: "DrawBoundingBoxesV2"
argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "DummyMemoryCache"
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method { member_method {
name: "DynamicPartition" name: "DynamicPartition"
argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "