Update ResourceTagger implementation.
PiperOrigin-RevId: 297953008 Change-Id: Ib018806c66daccc6b187083991a88cb849588cc9
This commit is contained in:
parent
9cb019e654
commit
b107574dd1
tensorflow/core
framework
kernels/data
platform
@ -571,8 +571,7 @@ void BackgroundWorker::Schedule(std::function<void()> work_item) {
|
||||
}
|
||||
|
||||
void BackgroundWorker::WorkerLoop() {
|
||||
tensorflow::ResourceTagger tag =
|
||||
tensorflow::ResourceTagger(kTFDataResourceTag, "Background");
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag, "Background");
|
||||
while (true) {
|
||||
std::function<void()> work_item = nullptr;
|
||||
{
|
||||
@ -609,8 +608,7 @@ namespace {
|
||||
class RunnerImpl : public Runner {
|
||||
public:
|
||||
void Run(const std::function<void()>& f) override {
|
||||
tensorflow::ResourceTagger tag =
|
||||
tensorflow::ResourceTagger(kTFDataResourceTag, "Runner");
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag, "Runner");
|
||||
f();
|
||||
|
||||
// NOTE: We invoke a virtual function to prevent `f` being tail-called, and
|
||||
|
@ -57,8 +57,8 @@ class ToTFRecordOp : public AsyncOpKernel {
|
||||
|
||||
private:
|
||||
Status DoCompute(OpKernelContext* ctx) {
|
||||
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
|
||||
kTFDataResourceTag, ctx->op_kernel().type_string());
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag,
|
||||
ctx->op_kernel().type_string());
|
||||
tstring filename;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ParseScalarArgument<tstring>(ctx, "filename", &filename));
|
||||
|
@ -521,8 +521,8 @@ void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
|
||||
Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
|
||||
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
|
||||
kTFDataResourceTag, ctx->op_kernel().type_string());
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag,
|
||||
ctx->op_kernel().type_string());
|
||||
DatasetBase* dataset;
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||
IteratorResource* iterator_resource;
|
||||
@ -533,8 +533,8 @@ Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
|
||||
}
|
||||
|
||||
void DeleteIteratorOp::Compute(OpKernelContext* ctx) {
|
||||
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
|
||||
kTFDataResourceTag, ctx->op_kernel().type_string());
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag,
|
||||
ctx->op_kernel().type_string());
|
||||
ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0);
|
||||
// The iterator resource is guaranteed to exist because the variant tensor
|
||||
// wrapping the deleter is provided as an unused input to this op, which
|
||||
@ -551,8 +551,8 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
|
||||
|
||||
protected:
|
||||
Status DoCompute(OpKernelContext* ctx) override {
|
||||
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
|
||||
kTFDataResourceTag, ctx->op_kernel().type_string());
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag,
|
||||
ctx->op_kernel().type_string());
|
||||
DatasetBase* dataset;
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||
|
||||
@ -610,8 +610,8 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {
|
||||
|
||||
protected:
|
||||
Status DoCompute(OpKernelContext* ctx) override {
|
||||
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
|
||||
kTFDataResourceTag, ctx->op_kernel().type_string());
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag,
|
||||
ctx->op_kernel().type_string());
|
||||
DatasetBase* dataset;
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||
OpInputList inputs;
|
||||
@ -749,8 +749,8 @@ class OneShotIteratorOp : public AsyncOpKernel {
|
||||
// running the initialization function, we must implement this
|
||||
// kernel as an async kernel.
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
|
||||
kTFDataResourceTag, ctx->op_kernel().type_string());
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag,
|
||||
ctx->op_kernel().type_string());
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (iterator_resource_ == nullptr && initialization_status_.ok()) {
|
||||
@ -913,8 +913,8 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
|
||||
",iter_num=", ctx->frame_iter().iter_id, "#");
|
||||
},
|
||||
profiler::kInfo);
|
||||
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
|
||||
kTFDataResourceTag, ctx->op_kernel().type_string());
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag,
|
||||
ctx->op_kernel().type_string());
|
||||
IteratorResource* iterator;
|
||||
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
|
||||
core::ScopedUnref unref_iterator(iterator);
|
||||
@ -933,8 +933,8 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
|
||||
}
|
||||
|
||||
Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
|
||||
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
|
||||
kTFDataResourceTag, ctx->op_kernel().type_string());
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag,
|
||||
ctx->op_kernel().type_string());
|
||||
IteratorResource* iterator;
|
||||
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
|
||||
core::ScopedUnref unref_iterator(iterator);
|
||||
@ -1049,8 +1049,8 @@ SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
|
||||
}
|
||||
|
||||
void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
|
||||
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
|
||||
kTFDataResourceTag, ctx->op_kernel().type_string());
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag,
|
||||
ctx->op_kernel().type_string());
|
||||
const Tensor& resource_handle_t = ctx->input(0);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
|
||||
errors::InvalidArgument("resource_handle must be a scalar"));
|
||||
@ -1074,8 +1074,8 @@ void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
|
||||
void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
|
||||
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
|
||||
kTFDataResourceTag, ctx->op_kernel().type_string());
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag,
|
||||
ctx->op_kernel().type_string());
|
||||
// Validate that the handle corresponds to a real resource, and
|
||||
// that it is an IteratorResource.
|
||||
IteratorResource* iterator_resource;
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/resource.h"
|
||||
@ -70,8 +71,7 @@ std::shared_ptr<ThreadFactory> UnboundedThreadPool::get_thread_factory() {
|
||||
|
||||
void UnboundedThreadPool::Schedule(std::function<void()> fn) {
|
||||
auto tagged_fn = [fn = std::move(fn)]() {
|
||||
tensorflow::ResourceTagger tag =
|
||||
tensorflow::ResourceTagger("tfdata", "ThreadPool");
|
||||
tensorflow::ResourceTagger tag(kTFDataResourceTag, "ThreadPool");
|
||||
fn();
|
||||
};
|
||||
ScheduleOnWorkQueue(std::move(tagged_fn), /*done=*/nullptr);
|
||||
|
@ -509,7 +509,9 @@ cc_library(
|
||||
cc_library(
|
||||
name = "resource",
|
||||
textual_hdrs = ["resource.h"],
|
||||
deps = tf_resource_deps(),
|
||||
deps = [
|
||||
":stringpiece",
|
||||
] + tf_resource_deps(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -14,11 +14,17 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/resource.h"
|
||||
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class ResourceTagger::ResourceTaggerImpl {
|
||||
public:
|
||||
ResourceTaggerImpl(StringPiece key, StringPiece value) {}
|
||||
};
|
||||
|
||||
ResourceTagger::ResourceTagger(StringPiece key, StringPiece value) {}
|
||||
|
||||
ResourceTagger::~ResourceTagger() {}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -16,14 +16,27 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_RESOURCE_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_RESOURCE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Tracks resource usage for tagged code paths.
|
||||
// ResourceTagger objects should only be allocated on the stack.
|
||||
class ResourceTagger {
|
||||
public:
|
||||
ResourceTagger(StringPiece key, StringPiece value);
|
||||
~ResourceTagger();
|
||||
|
||||
// Do not allow copying or moving ResourceTagger
|
||||
ResourceTagger(const ResourceTagger&) = delete;
|
||||
ResourceTagger(ResourceTagger&&) = delete;
|
||||
ResourceTagger& operator=(const ResourceTagger&) = delete;
|
||||
ResourceTagger& operator=(ResourceTagger&&) = delete;
|
||||
|
||||
private:
|
||||
class ResourceTaggerImpl;
|
||||
const std::unique_ptr<ResourceTaggerImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user