diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index dae5f61f1cd..97c4d212223 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -571,8 +571,7 @@ void BackgroundWorker::Schedule(std::function work_item) { } void BackgroundWorker::WorkerLoop() { - tensorflow::ResourceTagger tag = - tensorflow::ResourceTagger(kTFDataResourceTag, "Background"); + tensorflow::ResourceTagger tag(kTFDataResourceTag, "Background"); while (true) { std::function work_item = nullptr; { @@ -609,8 +608,7 @@ namespace { class RunnerImpl : public Runner { public: void Run(const std::function& f) override { - tensorflow::ResourceTagger tag = - tensorflow::ResourceTagger(kTFDataResourceTag, "Runner"); + tensorflow::ResourceTagger tag(kTFDataResourceTag, "Runner"); f(); // NOTE: We invoke a virtual function to prevent `f` being tail-called, and diff --git a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc index 6a910145b53..bfa894cd473 100644 --- a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc +++ b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc @@ -57,8 +57,8 @@ class ToTFRecordOp : public AsyncOpKernel { private: Status DoCompute(OpKernelContext* ctx) { - tensorflow::ResourceTagger tag = tensorflow::ResourceTagger( - kTFDataResourceTag, ctx->op_kernel().type_string()); + tensorflow::ResourceTagger tag(kTFDataResourceTag, + ctx->op_kernel().type_string()); tstring filename; TF_RETURN_IF_ERROR( ParseScalarArgument(ctx, "filename", &filename)); diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 4adf7f64fba..c3b365ead44 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -521,8 +521,8 @@ void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) { } Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) { - tensorflow::ResourceTagger tag = tensorflow::ResourceTagger( - kTFDataResourceTag, ctx->op_kernel().type_string()); + tensorflow::ResourceTagger tag(kTFDataResourceTag, + ctx->op_kernel().type_string()); DatasetBase* dataset; TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset)); IteratorResource* iterator_resource; @@ -533,8 +533,8 @@ Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) { } void DeleteIteratorOp::Compute(OpKernelContext* ctx) { - tensorflow::ResourceTagger tag = tensorflow::ResourceTagger( - kTFDataResourceTag, ctx->op_kernel().type_string()); + tensorflow::ResourceTagger tag(kTFDataResourceTag, + ctx->op_kernel().type_string()); ResourceHandle handle = ctx->input(0).flat()(0); // The iterator resource is guaranteed to exist because the variant tensor // wrapping the deleter is provided as an unused input to this op, which @@ -551,8 +551,8 @@ class ToSingleElementOp : public HybridAsyncOpKernel { protected: Status DoCompute(OpKernelContext* ctx) override { - tensorflow::ResourceTagger tag = tensorflow::ResourceTagger( - kTFDataResourceTag, ctx->op_kernel().type_string()); + tensorflow::ResourceTagger tag(kTFDataResourceTag, + ctx->op_kernel().type_string()); DatasetBase* dataset; TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset)); @@ -610,8 +610,8 @@ class ReduceDatasetOp : public HybridAsyncOpKernel { protected: Status DoCompute(OpKernelContext* ctx) override { - tensorflow::ResourceTagger tag = tensorflow::ResourceTagger( - kTFDataResourceTag, ctx->op_kernel().type_string()); + tensorflow::ResourceTagger tag(kTFDataResourceTag, + ctx->op_kernel().type_string()); DatasetBase* dataset; TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset)); OpInputList inputs; @@ -749,8 +749,8 @@ class OneShotIteratorOp : public AsyncOpKernel { // running the initialization function, we must implement this // kernel as an async kernel. void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - tensorflow::ResourceTagger tag = tensorflow::ResourceTagger( - kTFDataResourceTag, ctx->op_kernel().type_string()); + tensorflow::ResourceTagger tag(kTFDataResourceTag, + ctx->op_kernel().type_string()); { mutex_lock l(mu_); if (iterator_resource_ == nullptr && initialization_status_.ok()) { @@ -913,8 +913,8 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) { ",iter_num=", ctx->frame_iter().iter_id, "#"); }, profiler::kInfo); - tensorflow::ResourceTagger tag = tensorflow::ResourceTagger( - kTFDataResourceTag, ctx->op_kernel().type_string()); + tensorflow::ResourceTagger tag(kTFDataResourceTag, + ctx->op_kernel().type_string()); IteratorResource* iterator; TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); core::ScopedUnref unref_iterator(iterator); @@ -933,8 +933,8 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) { } Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) { - tensorflow::ResourceTagger tag = tensorflow::ResourceTagger( - kTFDataResourceTag, ctx->op_kernel().type_string()); + tensorflow::ResourceTagger tag(kTFDataResourceTag, + ctx->op_kernel().type_string()); IteratorResource* iterator; TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); core::ScopedUnref unref_iterator(iterator); @@ -1049,8 +1049,8 @@ SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx) } void SerializeIteratorOp::Compute(OpKernelContext* ctx) { - tensorflow::ResourceTagger tag = tensorflow::ResourceTagger( - kTFDataResourceTag, ctx->op_kernel().type_string()); + tensorflow::ResourceTagger tag(kTFDataResourceTag, + ctx->op_kernel().type_string()); const Tensor& resource_handle_t = ctx->input(0); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), errors::InvalidArgument("resource_handle must be a scalar")); @@ -1074,8 +1074,8 @@ void SerializeIteratorOp::Compute(OpKernelContext* ctx) { } void DeserializeIteratorOp::Compute(OpKernelContext* ctx) { - tensorflow::ResourceTagger tag = tensorflow::ResourceTagger( - kTFDataResourceTag, ctx->op_kernel().type_string()); + tensorflow::ResourceTagger tag(kTFDataResourceTag, + ctx->op_kernel().type_string()); // Validate that the handle corresponds to a real resource, and // that it is an IteratorResource. IteratorResource* iterator_resource; diff --git a/tensorflow/core/kernels/data/unbounded_thread_pool.cc b/tensorflow/core/kernels/data/unbounded_thread_pool.cc index 0f0baa83f98..b0b1023dfba 100644 --- a/tensorflow/core/kernels/data/unbounded_thread_pool.cc +++ b/tensorflow/core/kernels/data/unbounded_thread_pool.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/unbounded_thread_pool.h" #include "absl/memory/memory.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/resource.h" @@ -70,8 +71,7 @@ std::shared_ptr UnboundedThreadPool::get_thread_factory() { void UnboundedThreadPool::Schedule(std::function fn) { auto tagged_fn = [fn = std::move(fn)]() { - tensorflow::ResourceTagger tag = - tensorflow::ResourceTagger("tfdata", "ThreadPool"); + tensorflow::ResourceTagger tag(kTFDataResourceTag, "ThreadPool"); fn(); }; ScheduleOnWorkQueue(std::move(tagged_fn), /*done=*/nullptr); diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index eef097b0c75..63f8974506a 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -509,7 +509,9 @@ cc_library( cc_library( name = "resource", textual_hdrs = ["resource.h"], - deps = tf_resource_deps(), + deps = [ + ":stringpiece", + ] + tf_resource_deps(), ) cc_library( diff --git a/tensorflow/core/platform/default/resource.cc b/tensorflow/core/platform/default/resource.cc index 2ff257626f9..1fd1237844e 100644 --- a/tensorflow/core/platform/default/resource.cc +++ b/tensorflow/core/platform/default/resource.cc @@ -14,11 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/resource.h" - #include "tensorflow/core/platform/stringpiece.h" namespace tensorflow { +class ResourceTagger::ResourceTaggerImpl { + public: + ResourceTaggerImpl(StringPiece key, StringPiece value) {} +}; + ResourceTagger::ResourceTagger(StringPiece key, StringPiece value) {} +ResourceTagger::~ResourceTagger() {} + } // namespace tensorflow diff --git a/tensorflow/core/platform/resource.h b/tensorflow/core/platform/resource.h index b319f72ce76..6308f88039a 100644 --- a/tensorflow/core/platform/resource.h +++ b/tensorflow/core/platform/resource.h @@ -16,14 +16,27 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_RESOURCE_H_ #define TENSORFLOW_CORE_PLATFORM_RESOURCE_H_ +#include + #include "tensorflow/core/platform/stringpiece.h" namespace tensorflow { -// Tracks resource usage for tagged code paths. +// ResourceTagger objects should only be allocated on the stack. class ResourceTagger { public: ResourceTagger(StringPiece key, StringPiece value); + ~ResourceTagger(); + + // Do not allow copying or moving ResourceTagger + ResourceTagger(const ResourceTagger&) = delete; + ResourceTagger(ResourceTagger&&) = delete; + ResourceTagger& operator=(const ResourceTagger&) = delete; + ResourceTagger& operator=(ResourceTagger&&) = delete; + + private: + class ResourceTaggerImpl; + const std::unique_ptr impl_; }; } // namespace tensorflow