[tf.data] This CL changes how the in-memory cache resource is managed, making it possible for the cache dataset to support both a) sharing of the cache across iterators and b) serialization. As a consequence, this CL enables sharing of the cache across iterators for tf.distribute and tf.data service use cases (which require serialization support).
PiperOrigin-RevId: 307469217 Change-Id: Ia4f4384752609b83cb10078ebeb20dfc6c8a2d8f
This commit is contained in:
parent
b592f87bd8
commit
7ebbab819e
|
@ -0,0 +1,4 @@
|
|||
op {
|
||||
graph_op_name: "DummyMemoryCache"
|
||||
visibility: HIDDEN
|
||||
}
|
|
@ -27,8 +27,6 @@ namespace tensorflow {
|
|||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
constexpr char kCacheDataset[] = "CacheDataset";
|
||||
constexpr char kCacheDatasetV2[] = "CacheDatasetV2";
|
||||
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
|
||||
constexpr char kShuffleDataset[] = "ShuffleDataset";
|
||||
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
|
||||
|
@ -55,10 +53,6 @@ Status MakeStateless::OptimizeAndCollectStats(Cluster* cluster,
|
|||
node.add_input(zero_node->name());
|
||||
// set `reshuffle_each_iteration` attr
|
||||
(*node.mutable_attr())[kReshuffleEachIteration].set_b(true);
|
||||
} else if (node.op() == kCacheDatasetV2) {
|
||||
*node.mutable_op() = kCacheDataset;
|
||||
// remove `cache` input
|
||||
node.mutable_input()->RemoveLast();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -28,29 +28,6 @@ namespace tensorflow {
|
|||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
TEST(MakeStateless, Cache) {
|
||||
using test::function::NDef;
|
||||
GrapplerItem item;
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
|
||||
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
|
||||
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
|
||||
NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_INT64}}),
|
||||
NDef("handle", "Const", {}, {{"value", 1}, {"dtype", DT_RESOURCE}}),
|
||||
graph_tests_utils::MakeCacheV2Node("cache", "range", "filename",
|
||||
"handle")},
|
||||
{});
|
||||
|
||||
MakeStateless optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("cache", output));
|
||||
int index = graph_utils::FindGraphNodeWithName("cache", output);
|
||||
EXPECT_EQ(output.node(index).op(), "CacheDataset");
|
||||
EXPECT_EQ(output.node(index).input_size(), 2);
|
||||
}
|
||||
|
||||
TEST(MakeStateless, Shuffle) {
|
||||
using test::function::NDef;
|
||||
GrapplerItem item;
|
||||
|
|
|
@ -48,7 +48,6 @@ constexpr char kShardId[] = "shard_id";
|
|||
constexpr char kCreatedAt[] = "Created at";
|
||||
constexpr char kMemoryDatasetPrefix[] = "Memory";
|
||||
constexpr char kMemoryCache[] = "MemoryCache";
|
||||
constexpr char kTFData[] = "tf_data";
|
||||
constexpr char kCacheClaimed[] = "cache_claimed";
|
||||
constexpr char kCacheSize[] = "cache_size";
|
||||
constexpr char kCache[] = "cache";
|
||||
|
@ -58,9 +57,9 @@ constexpr char kIndex[] = "index";
|
|||
constexpr char kImpl[] = "Impl";
|
||||
constexpr char kCacheDataset[] = "CacheDataset";
|
||||
|
||||
class CacheDatasetOp::FileDataset : public DatasetBase {
|
||||
class CacheDatasetOp::FileDatasetBase : public DatasetBase {
|
||||
public:
|
||||
explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
FileDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
|
||||
string filename, Env* env)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
|
@ -76,7 +75,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
|||
DCHECK_EQ(item_index_padding_size_, 7);
|
||||
}
|
||||
|
||||
~FileDataset() override { input_->Unref(); }
|
||||
~FileDatasetBase() override { input_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
|
@ -107,17 +106,6 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
|||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
|
||||
Node* filename = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const DatasetBase* const input_;
|
||||
const tstring filename_;
|
||||
|
||||
|
@ -131,10 +119,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
|||
tensor_index);
|
||||
}
|
||||
|
||||
class FileIterator : public DatasetIterator<FileDataset> {
|
||||
class FileIterator : public DatasetIterator<FileDatasetBase> {
|
||||
public:
|
||||
explicit FileIterator(const Params& params)
|
||||
: DatasetIterator<FileDataset>(params) {
|
||||
: DatasetIterator<FileDatasetBase>(params) {
|
||||
if (params.dataset->env_
|
||||
->FileExists(MetaFilename(params.dataset->filename_))
|
||||
.ok()) {
|
||||
|
@ -199,7 +187,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
|||
|
||||
private:
|
||||
// FileWriterIterator passes through and caches items from the input
|
||||
// FileDataset.
|
||||
// FileDatasetBase.
|
||||
//
|
||||
// This iterator is used when the cache directory is not found on disk. It
|
||||
// creates the cache directory, and passes on the underlying iterator's
|
||||
|
@ -214,10 +202,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
|||
// partial cache gets flushed to disk in files with prefix
|
||||
// <filename>_<shard_id> where shard_id is unique for each checkpoint.
|
||||
// When all elements have been produced, these shards get coalesced.
|
||||
class FileWriterIterator : public DatasetIterator<FileDataset> {
|
||||
class FileWriterIterator : public DatasetIterator<FileDatasetBase> {
|
||||
public:
|
||||
explicit FileWriterIterator(const Params& params)
|
||||
: DatasetIterator<FileDataset>(params),
|
||||
: DatasetIterator<FileDatasetBase>(params),
|
||||
cur_index_(0),
|
||||
shard_id_(0),
|
||||
filename_(
|
||||
|
@ -483,10 +471,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
|||
bool iteration_completed_ TF_GUARDED_BY(mu_);
|
||||
}; // FileWriterIterator
|
||||
|
||||
class FileReaderIterator : public DatasetIterator<FileDataset> {
|
||||
class FileReaderIterator : public DatasetIterator<FileDatasetBase> {
|
||||
public:
|
||||
explicit FileReaderIterator(const Params& params)
|
||||
: DatasetIterator<FileDataset>(params),
|
||||
: DatasetIterator<FileDatasetBase>(params),
|
||||
cur_index_(0),
|
||||
reader_(dataset()->env_, dataset()->filename_),
|
||||
iterator_restored_(false) {}
|
||||
|
@ -606,14 +594,31 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
|||
static constexpr size_t kMaxItems = 10000000; // 10 million
|
||||
const size_t item_index_padding_size_;
|
||||
const string tensor_format_string_;
|
||||
}; // FileDataset
|
||||
}; // FileDatasetBase
|
||||
|
||||
class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDataset {
|
||||
class CacheDatasetOp::FileDataset : public CacheDatasetOp::FileDatasetBase {
|
||||
public:
|
||||
using FileDatasetBase::FileDatasetBase;
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
|
||||
Node* filename = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDatasetBase {
|
||||
public:
|
||||
explicit FileDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
|
||||
string filename, Env* env,
|
||||
const Tensor& resource_handle)
|
||||
: FileDataset(ctx, input, filename, env),
|
||||
: FileDatasetBase(ctx, input, filename, env),
|
||||
resource_handle_(resource_handle) {}
|
||||
|
||||
protected:
|
||||
|
@ -686,20 +691,15 @@ Status RestoreCache(IteratorContext* ctx, IteratorStateReader* reader, T* cache,
|
|||
|
||||
} // namespace
|
||||
|
||||
class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
|
||||
public:
|
||||
explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
|
||||
MemoryCache* cache)
|
||||
: DatasetBase(DatasetContext(ctx)), input_(input), cache_(cache) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
~MemoryDataset() override {
|
||||
input_->Unref();
|
||||
if (cache_) {
|
||||
cache_->Unref();
|
||||
}
|
||||
}
|
||||
~MemoryDatasetBase() override { input_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
|
@ -732,44 +732,13 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
|
||||
Node* filename_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, {input_node, filename_node}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class MemoryIterator : public DatasetIterator<MemoryDataset> {
|
||||
class MemoryIterator : public DatasetIterator<MemoryDatasetBase> {
|
||||
public:
|
||||
explicit MemoryIterator(const Params& params, MemoryCache* cache)
|
||||
: DatasetIterator<MemoryDataset>(params), cache_(cache) {}
|
||||
|
||||
~MemoryIterator() override {
|
||||
if (dataset()->cache_ == nullptr) {
|
||||
cache_->Unref();
|
||||
}
|
||||
}
|
||||
: DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
mutex_lock l(mu_);
|
||||
if (cache_ == nullptr) {
|
||||
// Use the resource manager in the iterator context to get / create
|
||||
// a cache.
|
||||
ResourceMgr* mgr = ctx->resource_mgr();
|
||||
const string name = strings::StrCat(
|
||||
prefix(), name_utils::kDelimiter, dataset()->node_name(),
|
||||
name_utils::kDelimiter, kMemoryCache);
|
||||
TF_RETURN_IF_ERROR(mgr->LookupOrCreate<MemoryCache>(
|
||||
kTFData, name, &cache_, [](MemoryCache** cache) {
|
||||
*cache = new MemoryCache();
|
||||
return Status::OK();
|
||||
}));
|
||||
}
|
||||
InitializeIterator();
|
||||
return iterator_->Initialize(ctx);
|
||||
}
|
||||
|
@ -817,10 +786,10 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||
}
|
||||
|
||||
private:
|
||||
class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
|
||||
class MemoryWriterIterator : public DatasetIterator<MemoryDatasetBase> {
|
||||
public:
|
||||
explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
|
||||
: DatasetIterator<MemoryDataset>(params), cache_(cache) {}
|
||||
: DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}
|
||||
|
||||
~MemoryWriterIterator() override {
|
||||
mutex_lock l(mu_);
|
||||
|
@ -900,12 +869,12 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||
std::vector<std::vector<Tensor>> temp_cache_ TF_GUARDED_BY(mu_);
|
||||
}; // MemoryWriterIterator
|
||||
|
||||
class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
|
||||
class MemoryReaderIterator : public DatasetIterator<MemoryDatasetBase> {
|
||||
public:
|
||||
explicit MemoryReaderIterator(const Params& params, MemoryCache* cache)
|
||||
: DatasetIterator<MemoryDataset>(params), cache_(cache), index_(0) {
|
||||
CHECK(cache);
|
||||
}
|
||||
: DatasetIterator<MemoryDatasetBase>(params),
|
||||
cache_(cache),
|
||||
index_(0) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
// The memory allocated for the cache is owned by the parent
|
||||
|
@ -988,19 +957,80 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||
}; // MemoryIterator
|
||||
|
||||
const DatasetBase* const input_;
|
||||
MemoryCache* cache_ = nullptr;
|
||||
}; // MemoryDataset
|
||||
MemoryCache* const cache_;
|
||||
}; // MemoryDatasetBase
|
||||
|
||||
class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset {
|
||||
// This version of memory dataset has an exclusive ownership of the memory cache
|
||||
// resource. It supports sharing of the cache across different iterations of the
|
||||
// `repeat` transformation but not across different iterators.
|
||||
class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
|
||||
public:
|
||||
explicit MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
|
||||
MemoryCache* cache,
|
||||
std::unique_ptr<OwnedResourceHandle> handle)
|
||||
: MemoryDataset(ctx, input, cache), handle_(std::move(handle)) {}
|
||||
MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
MemoryCache* cache, const ResourceHandle& resource_handle)
|
||||
: MemoryDatasetBase(ctx, input, cache),
|
||||
resource_handle_(resource_handle) {
|
||||
cleanup_ = [this, mgr = ctx->resource_manager()]() {
|
||||
DCHECK(cache_->RefCountIsOne());
|
||||
Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
|
||||
resource_handle_.name());
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on memory cache resource.");
|
||||
~MemoryDataset() override {
|
||||
cache_->Unref();
|
||||
cleanup_();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
|
||||
Node* filename_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, {input_node, filename_node}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<void()> cleanup_;
|
||||
const ResourceHandle resource_handle_;
|
||||
};
|
||||
|
||||
// This version of memory dataset has a shared ownership of the memory cache
|
||||
// resource. It supports sharing of the cache across different iterations of
|
||||
// the `repeat` transformation and also across different iterators.
|
||||
class CacheDatasetOp::MemoryDatasetV2
|
||||
: public CacheDatasetOp::MemoryDatasetBase {
|
||||
public:
|
||||
MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
|
||||
MemoryCache* cache, const ResourceHandle& resource_handle)
|
||||
: MemoryDatasetBase(ctx, input, cache),
|
||||
resource_handle_(std::move(resource_handle)) {
|
||||
cleanup_ = [this, mgr = ctx->resource_manager()]() {
|
||||
if (cache_->RefCountIsOne()) {
|
||||
Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
|
||||
resource_handle_.name());
|
||||
if (!s.ok()) {
|
||||
if (errors::IsNotFound(s)) {
|
||||
// This is a bening race resulting from concurrent deletion.
|
||||
VLOG(1) << "Failed to delete cache resource: " << s.ToString();
|
||||
} else {
|
||||
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
~MemoryDatasetV2() override {
|
||||
cache_->Unref();
|
||||
cleanup_();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -1013,7 +1043,7 @@ class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset {
|
|||
TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
|
||||
Node* resource_handle_node = nullptr;
|
||||
Tensor handle(DT_RESOURCE, TensorShape({}));
|
||||
handle.scalar<ResourceHandle>()() = handle_->handle();
|
||||
handle.scalar<ResourceHandle>()() = resource_handle_;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||
this, {input_node, filename_node, resource_handle_node}, output));
|
||||
|
@ -1021,7 +1051,8 @@ class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset {
|
|||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<OwnedResourceHandle> handle_;
|
||||
std::function<void()> cleanup_;
|
||||
const ResourceHandle resource_handle_;
|
||||
};
|
||||
|
||||
CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
|
||||
|
@ -1033,22 +1064,39 @@ void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
|||
// Parse out the filenames tensor.
|
||||
tstring filename;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, kFileName, &filename));
|
||||
|
||||
if (filename.empty()) {
|
||||
static std::atomic<int64> resource_id_counter(0);
|
||||
const string& container = ctx->resource_manager()->default_container();
|
||||
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
|
||||
resource_id_counter.fetch_add(1));
|
||||
if (op_version_ == 2) {
|
||||
MemoryCache* cache = nullptr;
|
||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &cache));
|
||||
|
||||
// Create a fresh handle for the resource because the input handle can
|
||||
// become invalid after this op executes.
|
||||
std::unique_ptr<OwnedResourceHandle> handle;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, OwnedResourceHandle::Create(ctx, cache, kMemoryCache, &handle));
|
||||
|
||||
auto handle = HandleFromInput(ctx, 2);
|
||||
Status s = ctx->resource_manager()->Lookup<MemoryCache>(
|
||||
handle.container(), handle.name(), &cache);
|
||||
if (errors::IsNotFound(s)) {
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<MemoryCache>(
|
||||
container, name, &cache, [](MemoryCache** cache) {
|
||||
*cache = new MemoryCache();
|
||||
return Status::OK();
|
||||
}));
|
||||
handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
// Ownership of cache is transferred onto `MemoryDatasetV2`.
|
||||
*output = new MemoryDatasetV2(ctx, input, cache, std::move(handle));
|
||||
} else {
|
||||
*output = new MemoryDataset(ctx, input, /*cache=*/nullptr);
|
||||
MemoryCache* cache;
|
||||
OP_REQUIRES_OK(ctx, ctx->resource_manager()->LookupOrCreate<MemoryCache>(
|
||||
container, name, &cache, [](MemoryCache** cache) {
|
||||
*cache = new MemoryCache();
|
||||
return Status::OK();
|
||||
}));
|
||||
auto handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
|
||||
// Ownership of cache is transferred onto `MemoryDataset`.
|
||||
*output = new MemoryDataset(ctx, input, cache, handle);
|
||||
}
|
||||
} else {
|
||||
if (op_version_ == 2) {
|
||||
|
|
|
@ -22,8 +22,8 @@ namespace data {
|
|||
|
||||
class CacheDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
class FileDataset;
|
||||
class MemoryDataset;
|
||||
class FileDatasetBase;
|
||||
class MemoryDatasetBase;
|
||||
|
||||
static constexpr const char* const kDatasetType = "Cache";
|
||||
static constexpr const char* const kInputDataset = "input_dataset";
|
||||
|
@ -38,10 +38,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
|
|||
DatasetBase** output) override;
|
||||
|
||||
private:
|
||||
class FileDataset;
|
||||
class FileDatasetV2;
|
||||
class MemoryDataset;
|
||||
class MemoryDatasetV2;
|
||||
|
||||
int op_version_;
|
||||
const int op_version_;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
|
@ -26,7 +27,7 @@ namespace tensorflow {
|
|||
namespace data {
|
||||
namespace {
|
||||
|
||||
const char kMemoryCache[] = "MemoryCache";
|
||||
constexpr char kMemoryCache[] = "MemoryCache";
|
||||
|
||||
} // namespace
|
||||
|
||||
|
@ -82,10 +83,11 @@ Status AnonymousMemoryCacheHandleOp::CreateResource(
|
|||
|
||||
void DeleteMemoryCacheOp::Compute(OpKernelContext* ctx) {
|
||||
const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
|
||||
// The resource is guaranteed to exist because the variant tensor wrapping the
|
||||
// deleter is provided as an unused input to this op, which guarantees that it
|
||||
// has not run yet.
|
||||
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Delete(handle));
|
||||
// The resource might have been already deleted by the dataset.
|
||||
Status s = ctx->resource_manager()->Delete(handle);
|
||||
if (!errors::IsNotFound(s)) {
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -96,6 +98,9 @@ REGISTER_KERNEL_BUILDER(Name("AnonymousMemoryCache").Device(DEVICE_CPU),
|
|||
REGISTER_KERNEL_BUILDER(Name("DeleteMemoryCache").Device(DEVICE_CPU),
|
||||
DeleteMemoryCacheOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DummyMemoryCache").Device(DEVICE_CPU),
|
||||
DummyResourceOp<MemoryCache>);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -286,6 +286,28 @@ Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
|
|||
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
|
||||
std::function<void(std::function<void()>)> runner, int max_parallelism);
|
||||
|
||||
// Op for creating a typed dummy resource.
|
||||
//
|
||||
// This op is used to provide a resource "placeholder" for ops such as
|
||||
// `CacheDatasetV2` or `ShuffleDatasetV2` that expects a resource input.
|
||||
// Originally, the lifetime of the resources passed into these ops was managed
|
||||
// externally. After the implementation changed to manage the lifetime of the
|
||||
// resources (including creation) by the ops themselves, the resource input is
|
||||
// only needed to pass a resource handle through graph rewrites. When they are
|
||||
// invoked from user code, the implementation passes in a dummy resource.
|
||||
template <typename ResourceType>
|
||||
class DummyResourceOp : public OpKernel {
|
||||
public:
|
||||
explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Tensor* tensor;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor));
|
||||
tensor->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
|
||||
ctx, /*container=*/"", /*name=*/"dummy_resource");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
|
|
|
@ -504,6 +504,13 @@ REGISTER_OP("DeleteMemoryCache")
|
|||
.Input("deleter: variant")
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
|
||||
REGISTER_OP("DummyMemoryCache")
|
||||
.Output("handle: resource")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(0, c->Scalar());
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("CacheDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("filename: string")
|
||||
|
|
|
@ -244,8 +244,6 @@ class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
dataset_ops.Dataset.from_tensor_slices(components).repeat(0))
|
||||
cache_dataset = repeat_dataset.cache()
|
||||
|
||||
# Create initialization ops for iterators without and with
|
||||
# caching, respectively.
|
||||
self.assertDatasetProduces(cache_dataset, expected_output=[])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
|
|
|
@ -3526,6 +3526,8 @@ class RangeDataset(DatasetSource):
|
|||
return self._structure
|
||||
|
||||
|
||||
# This can be deleted after the forward compatibility window for switching
|
||||
# to using dummy resource expires on 5/20.
|
||||
class _MemoryCacheDeleter(object):
|
||||
"""An object which cleans up an anonymous memory cache resource.
|
||||
|
||||
|
@ -3552,13 +3554,18 @@ class _MemoryCacheDeleter(object):
|
|||
handle=self._handle, deleter=self._deleter)
|
||||
|
||||
|
||||
# This can be deleted after the forward compatibility window for switching
|
||||
# to using dummy resource expires on 5/20.
|
||||
class _MemoryCache(object):
|
||||
"""Represents a memory cache resource."""
|
||||
|
||||
def __init__(self):
|
||||
super(_MemoryCache, self).__init__()
|
||||
if compat.forward_compatible(2020, 5, 20):
|
||||
self._handle = gen_dataset_ops.dummy_memory_cache()
|
||||
else:
|
||||
self._device = context.context().device_name
|
||||
self._handle, self._deleter = (gen_dataset_ops.anonymous_memory_cache())
|
||||
self._handle, self._deleter = gen_dataset_ops.anonymous_memory_cache()
|
||||
self._resource_deleter = _MemoryCacheDeleter(
|
||||
handle=self._handle, device=self._device, deleter=self._deleter)
|
||||
|
||||
|
|
|
@ -1176,6 +1176,10 @@ tf_module {
|
|||
name: "DrawBoundingBoxesV2"
|
||||
argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DummyMemoryCache"
|
||||
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DynamicPartition"
|
||||
argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
|
|
@ -1176,6 +1176,10 @@ tf_module {
|
|||
name: "DrawBoundingBoxesV2"
|
||||
argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DummyMemoryCache"
|
||||
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DynamicPartition"
|
||||
argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
|
Loading…
Reference in New Issue