Update ResourceTagger implementation.

PiperOrigin-RevId: 297953008
Change-Id: Ib018806c66daccc6b187083991a88cb849588cc9
This commit is contained in:
A. Unique TensorFlower 2020-02-28 15:50:07 -08:00 committed by TensorFlower Gardener
parent 9cb019e654
commit b107574dd1
7 changed files with 48 additions and 29 deletions

View File

@ -571,8 +571,7 @@ void BackgroundWorker::Schedule(std::function<void()> work_item) {
}
void BackgroundWorker::WorkerLoop() {
tensorflow::ResourceTagger tag =
tensorflow::ResourceTagger(kTFDataResourceTag, "Background");
tensorflow::ResourceTagger tag(kTFDataResourceTag, "Background");
while (true) {
std::function<void()> work_item = nullptr;
{
@ -609,8 +608,7 @@ namespace {
class RunnerImpl : public Runner {
public:
void Run(const std::function<void()>& f) override {
tensorflow::ResourceTagger tag =
tensorflow::ResourceTagger(kTFDataResourceTag, "Runner");
tensorflow::ResourceTagger tag(kTFDataResourceTag, "Runner");
f();
// NOTE: We invoke a virtual function to prevent `f` being tail-called, and

View File

@ -57,8 +57,8 @@ class ToTFRecordOp : public AsyncOpKernel {
private:
Status DoCompute(OpKernelContext* ctx) {
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
kTFDataResourceTag, ctx->op_kernel().type_string());
tensorflow::ResourceTagger tag(kTFDataResourceTag,
ctx->op_kernel().type_string());
tstring filename;
TF_RETURN_IF_ERROR(
ParseScalarArgument<tstring>(ctx, "filename", &filename));

View File

@ -521,8 +521,8 @@ void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) {
}
Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
kTFDataResourceTag, ctx->op_kernel().type_string());
tensorflow::ResourceTagger tag(kTFDataResourceTag,
ctx->op_kernel().type_string());
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
IteratorResource* iterator_resource;
@ -533,8 +533,8 @@ Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
}
void DeleteIteratorOp::Compute(OpKernelContext* ctx) {
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
kTFDataResourceTag, ctx->op_kernel().type_string());
tensorflow::ResourceTagger tag(kTFDataResourceTag,
ctx->op_kernel().type_string());
ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0);
// The iterator resource is guaranteed to exist because the variant tensor
// wrapping the deleter is provided as an unused input to this op, which
@ -551,8 +551,8 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
protected:
Status DoCompute(OpKernelContext* ctx) override {
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
kTFDataResourceTag, ctx->op_kernel().type_string());
tensorflow::ResourceTagger tag(kTFDataResourceTag,
ctx->op_kernel().type_string());
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
@ -610,8 +610,8 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {
protected:
Status DoCompute(OpKernelContext* ctx) override {
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
kTFDataResourceTag, ctx->op_kernel().type_string());
tensorflow::ResourceTagger tag(kTFDataResourceTag,
ctx->op_kernel().type_string());
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
OpInputList inputs;
@ -749,8 +749,8 @@ class OneShotIteratorOp : public AsyncOpKernel {
// running the initialization function, we must implement this
// kernel as an async kernel.
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
kTFDataResourceTag, ctx->op_kernel().type_string());
tensorflow::ResourceTagger tag(kTFDataResourceTag,
ctx->op_kernel().type_string());
{
mutex_lock l(mu_);
if (iterator_resource_ == nullptr && initialization_status_.ok()) {
@ -913,8 +913,8 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
",iter_num=", ctx->frame_iter().iter_id, "#");
},
profiler::kInfo);
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
kTFDataResourceTag, ctx->op_kernel().type_string());
tensorflow::ResourceTagger tag(kTFDataResourceTag,
ctx->op_kernel().type_string());
IteratorResource* iterator;
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
core::ScopedUnref unref_iterator(iterator);
@ -933,8 +933,8 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
}
Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
kTFDataResourceTag, ctx->op_kernel().type_string());
tensorflow::ResourceTagger tag(kTFDataResourceTag,
ctx->op_kernel().type_string());
IteratorResource* iterator;
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
core::ScopedUnref unref_iterator(iterator);
@ -1049,8 +1049,8 @@ SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
}
void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
kTFDataResourceTag, ctx->op_kernel().type_string());
tensorflow::ResourceTagger tag(kTFDataResourceTag,
ctx->op_kernel().type_string());
const Tensor& resource_handle_t = ctx->input(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));
@ -1074,8 +1074,8 @@ void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
}
void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
tensorflow::ResourceTagger tag = tensorflow::ResourceTagger(
kTFDataResourceTag, ctx->op_kernel().type_string());
tensorflow::ResourceTagger tag(kTFDataResourceTag,
ctx->op_kernel().type_string());
// Validate that the handle corresponds to a real resource, and
// that it is an IteratorResource.
IteratorResource* iterator_resource;

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/resource.h"
@ -70,8 +71,7 @@ std::shared_ptr<ThreadFactory> UnboundedThreadPool::get_thread_factory() {
void UnboundedThreadPool::Schedule(std::function<void()> fn) {
auto tagged_fn = [fn = std::move(fn)]() {
tensorflow::ResourceTagger tag =
tensorflow::ResourceTagger("tfdata", "ThreadPool");
tensorflow::ResourceTagger tag(kTFDataResourceTag, "ThreadPool");
fn();
};
ScheduleOnWorkQueue(std::move(tagged_fn), /*done=*/nullptr);

View File

@ -509,7 +509,9 @@ cc_library(
cc_library(
name = "resource",
textual_hdrs = ["resource.h"],
deps = tf_resource_deps(),
deps = [
":stringpiece",
] + tf_resource_deps(),
)
cc_library(

View File

@ -14,11 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/resource.h"
#include "tensorflow/core/platform/stringpiece.h"
namespace tensorflow {
class ResourceTagger::ResourceTaggerImpl {
public:
ResourceTaggerImpl(StringPiece key, StringPiece value) {}
};
ResourceTagger::ResourceTagger(StringPiece key, StringPiece value) {}
ResourceTagger::~ResourceTagger() {}
} // namespace tensorflow

View File

@ -16,14 +16,27 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_PLATFORM_RESOURCE_H_
#define TENSORFLOW_CORE_PLATFORM_RESOURCE_H_
#include <memory>
#include "tensorflow/core/platform/stringpiece.h"
namespace tensorflow {
// Tracks resource usage for tagged code paths.
// ResourceTagger objects should only be allocated on the stack.
class ResourceTagger {
public:
ResourceTagger(StringPiece key, StringPiece value);
~ResourceTagger();
// Do not allow copying or moving ResourceTagger
ResourceTagger(const ResourceTagger&) = delete;
ResourceTagger(ResourceTagger&&) = delete;
ResourceTagger& operator=(const ResourceTagger&) = delete;
ResourceTagger& operator=(ResourceTagger&&) = delete;
private:
class ResourceTaggerImpl;
const std::unique_ptr<ResourceTaggerImpl> impl_;
};
} // namespace tensorflow