Update ResourceTagger implementation.
PiperOrigin-RevId: 297953008 Change-Id: Ib018806c66daccc6b187083991a88cb849588cc9
This commit is contained in:
		
							parent
							
								
									9cb019e654
								
							
						
					
					
						commit
						b107574dd1
					
				@ -571,8 +571,7 @@ void BackgroundWorker::Schedule(std::function<void()> work_item) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void BackgroundWorker::WorkerLoop() {
 | 
			
		||||
  tensorflow::ResourceTagger tag =
 | 
			
		||||
      tensorflow::ResourceTagger(kTFDataResourceTag, "Background");
 | 
			
		||||
  tensorflow::ResourceTagger tag(kTFDataResourceTag, "Background");
 | 
			
		||||
  while (true) {
 | 
			
		||||
    std::function<void()> work_item = nullptr;
 | 
			
		||||
    {
 | 
			
		||||
@ -609,8 +608,7 @@ namespace {
 | 
			
		||||
class RunnerImpl : public Runner {
 | 
			
		||||
 public:
 | 
			
		||||
  void Run(const std::function<void()>& f) override {
 | 
			
		||||
    tensorflow::ResourceTagger tag =
 | 
			
		||||
        tensorflow::ResourceTagger(kTFDataResourceTag, "Runner");
 | 
			
		||||
    tensorflow::ResourceTagger tag(kTFDataResourceTag, "Runner");
 | 
			
		||||
    f();
 | 
			
		||||
 | 
			
		||||
    // NOTE: We invoke a virtual function to prevent `f` being tail-called, and
 | 
			
		||||
 | 
			
		||||
@ -57,8 +57,8 @@ class ToTFRecordOp : public AsyncOpKernel {
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  Status DoCompute(OpKernelContext* ctx) {
 | 
			
		||||
    tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
 | 
			
		||||
        kTFDataResourceTag, ctx->op_kernel().type_string());
 | 
			
		||||
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
 | 
			
		||||
                                   ctx->op_kernel().type_string());
 | 
			
		||||
    tstring filename;
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        ParseScalarArgument<tstring>(ctx, "filename", &filename));
 | 
			
		||||
 | 
			
		||||
@ -521,8 +521,8 @@ void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
 | 
			
		||||
  tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
 | 
			
		||||
      kTFDataResourceTag, ctx->op_kernel().type_string());
 | 
			
		||||
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
 | 
			
		||||
                                 ctx->op_kernel().type_string());
 | 
			
		||||
  DatasetBase* dataset;
 | 
			
		||||
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
 | 
			
		||||
  IteratorResource* iterator_resource;
 | 
			
		||||
@ -533,8 +533,8 @@ Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void DeleteIteratorOp::Compute(OpKernelContext* ctx) {
 | 
			
		||||
  tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
 | 
			
		||||
      kTFDataResourceTag, ctx->op_kernel().type_string());
 | 
			
		||||
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
 | 
			
		||||
                                 ctx->op_kernel().type_string());
 | 
			
		||||
  ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0);
 | 
			
		||||
  // The iterator resource is guaranteed to exist because the variant tensor
 | 
			
		||||
  // wrapping the deleter is provided as an unused input to this op, which
 | 
			
		||||
@ -551,8 +551,8 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  Status DoCompute(OpKernelContext* ctx) override {
 | 
			
		||||
    tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
 | 
			
		||||
        kTFDataResourceTag, ctx->op_kernel().type_string());
 | 
			
		||||
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
 | 
			
		||||
                                   ctx->op_kernel().type_string());
 | 
			
		||||
    DatasetBase* dataset;
 | 
			
		||||
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
 | 
			
		||||
 | 
			
		||||
@ -610,8 +610,8 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  Status DoCompute(OpKernelContext* ctx) override {
 | 
			
		||||
    tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
 | 
			
		||||
        kTFDataResourceTag, ctx->op_kernel().type_string());
 | 
			
		||||
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
 | 
			
		||||
                                   ctx->op_kernel().type_string());
 | 
			
		||||
    DatasetBase* dataset;
 | 
			
		||||
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
 | 
			
		||||
    OpInputList inputs;
 | 
			
		||||
@ -749,8 +749,8 @@ class OneShotIteratorOp : public AsyncOpKernel {
 | 
			
		||||
  // running the initialization function, we must implement this
 | 
			
		||||
  // kernel as an async kernel.
 | 
			
		||||
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
 | 
			
		||||
    tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
 | 
			
		||||
        kTFDataResourceTag, ctx->op_kernel().type_string());
 | 
			
		||||
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
 | 
			
		||||
                                   ctx->op_kernel().type_string());
 | 
			
		||||
    {
 | 
			
		||||
      mutex_lock l(mu_);
 | 
			
		||||
      if (iterator_resource_ == nullptr && initialization_status_.ok()) {
 | 
			
		||||
@ -913,8 +913,8 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
 | 
			
		||||
            ",iter_num=", ctx->frame_iter().iter_id, "#");
 | 
			
		||||
      },
 | 
			
		||||
      profiler::kInfo);
 | 
			
		||||
  tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
 | 
			
		||||
      kTFDataResourceTag, ctx->op_kernel().type_string());
 | 
			
		||||
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
 | 
			
		||||
                                 ctx->op_kernel().type_string());
 | 
			
		||||
  IteratorResource* iterator;
 | 
			
		||||
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
 | 
			
		||||
  core::ScopedUnref unref_iterator(iterator);
 | 
			
		||||
@ -933,8 +933,8 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
 | 
			
		||||
  tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
 | 
			
		||||
      kTFDataResourceTag, ctx->op_kernel().type_string());
 | 
			
		||||
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
 | 
			
		||||
                                 ctx->op_kernel().type_string());
 | 
			
		||||
  IteratorResource* iterator;
 | 
			
		||||
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
 | 
			
		||||
  core::ScopedUnref unref_iterator(iterator);
 | 
			
		||||
@ -1049,8 +1049,8 @@ SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
 | 
			
		||||
  tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
 | 
			
		||||
      kTFDataResourceTag, ctx->op_kernel().type_string());
 | 
			
		||||
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
 | 
			
		||||
                                 ctx->op_kernel().type_string());
 | 
			
		||||
  const Tensor& resource_handle_t = ctx->input(0);
 | 
			
		||||
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
 | 
			
		||||
              errors::InvalidArgument("resource_handle must be a scalar"));
 | 
			
		||||
@ -1074,8 +1074,8 @@ void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
 | 
			
		||||
  tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
 | 
			
		||||
      kTFDataResourceTag, ctx->op_kernel().type_string());
 | 
			
		||||
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
 | 
			
		||||
                                 ctx->op_kernel().type_string());
 | 
			
		||||
  // Validate that the handle corresponds to a real resource, and
 | 
			
		||||
  // that it is an IteratorResource.
 | 
			
		||||
  IteratorResource* iterator_resource;
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
 | 
			
		||||
 | 
			
		||||
#include "absl/memory/memory.h"
 | 
			
		||||
#include "tensorflow/core/framework/dataset.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/notification.h"
 | 
			
		||||
#include "tensorflow/core/platform/env.h"
 | 
			
		||||
#include "tensorflow/core/platform/resource.h"
 | 
			
		||||
@ -70,8 +71,7 @@ std::shared_ptr<ThreadFactory> UnboundedThreadPool::get_thread_factory() {
 | 
			
		||||
 | 
			
		||||
void UnboundedThreadPool::Schedule(std::function<void()> fn) {
 | 
			
		||||
  auto tagged_fn = [fn = std::move(fn)]() {
 | 
			
		||||
    tensorflow::ResourceTagger tag =
 | 
			
		||||
        tensorflow::ResourceTagger("tfdata", "ThreadPool");
 | 
			
		||||
    tensorflow::ResourceTagger tag(kTFDataResourceTag, "ThreadPool");
 | 
			
		||||
    fn();
 | 
			
		||||
  };
 | 
			
		||||
  ScheduleOnWorkQueue(std::move(tagged_fn), /*done=*/nullptr);
 | 
			
		||||
 | 
			
		||||
@ -509,7 +509,9 @@ cc_library(
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "resource",
 | 
			
		||||
    textual_hdrs = ["resource.h"],
 | 
			
		||||
    deps = tf_resource_deps(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":stringpiece",
 | 
			
		||||
    ] + tf_resource_deps(),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
 | 
			
		||||
@ -14,11 +14,17 @@ limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/platform/resource.h"
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/platform/stringpiece.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
class ResourceTagger::ResourceTaggerImpl {
 | 
			
		||||
 public:
 | 
			
		||||
  ResourceTaggerImpl(StringPiece key, StringPiece value) {}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
ResourceTagger::ResourceTagger(StringPiece key, StringPiece value) {}
 | 
			
		||||
 | 
			
		||||
ResourceTagger::~ResourceTagger() {}
 | 
			
		||||
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
@ -16,14 +16,27 @@ limitations under the License.
 | 
			
		||||
#ifndef TENSORFLOW_CORE_PLATFORM_RESOURCE_H_
 | 
			
		||||
#define TENSORFLOW_CORE_PLATFORM_RESOURCE_H_
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/platform/stringpiece.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
// Tracks resource usage for tagged code paths.
 | 
			
		||||
// ResourceTagger objects should only be allocated on the stack.
 | 
			
		||||
class ResourceTagger {
 | 
			
		||||
 public:
 | 
			
		||||
  ResourceTagger(StringPiece key, StringPiece value);
 | 
			
		||||
  ~ResourceTagger();
 | 
			
		||||
 | 
			
		||||
  // Do not allow copying or moving ResourceTagger
 | 
			
		||||
  ResourceTagger(const ResourceTagger&) = delete;
 | 
			
		||||
  ResourceTagger(ResourceTagger&&) = delete;
 | 
			
		||||
  ResourceTagger& operator=(const ResourceTagger&) = delete;
 | 
			
		||||
  ResourceTagger& operator=(ResourceTagger&&) = delete;
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  class ResourceTaggerImpl;
 | 
			
		||||
  const std::unique_ptr<ResourceTaggerImpl> impl_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user