diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 9a7f75b18a2..730fc36202c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -127,9 +127,9 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, step_id, [&status, device](const string& name) { status = device->resource_manager()->Cleanup(name); }); - TF_RETURN_IF_ERROR(device->resource_manager()->Create( - step_container->name(), XlaContext::kXlaContextResourceName, - xla_context)); + TF_RETURN_IF_ERROR(step_container->Create(device->resource_manager(), + XlaContext::kXlaContextResourceName, + xla_context)); GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); TF_RETURN_IF_ERROR(graph_compiler.Compile()); diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 3f787fd86c9..e49c944eeb3 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -45,8 +45,8 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context"; // per-step context is looked up in the resource manager. The // JIT will prepopulate the JITContext. XlaContext* context; - TF_CHECK_OK(ctx->resource_manager()->Lookup( - ctx->step_container()->name(), kXlaContextResourceName, &context)); + TF_CHECK_OK(ctx->step_container()->Lookup(ctx->resource_manager(), + kXlaContextResourceName, &context)); // The resource manager handed us a fresh reference to 'context', but retains // a reference itself so the context won't be freed. The resource manager will // outlive the JIT compilation. diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 228a236edc2..f9cddab6883 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/fingerprint.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/profiler/lib/scoped_annotation.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/public/version.h" @@ -207,25 +206,20 @@ Status KernelAndDeviceOp::Run( const EagerKernelArgs& inputs, std::vector* outputs, CancellationManager* cancellation_manager, const absl::optional& remote_func_params) { - ScopedStepContainer step_container(0, [this](const string& name) { - device_->resource_manager()->Cleanup(name).IgnoreError(); - }); - return this->Run(&step_container, inputs, outputs, cancellation_manager, - remote_func_params); + Status s = this->Run(&step_container_, inputs, outputs, cancellation_manager, + remote_func_params); + step_container_.CleanUp(); + return s; } Status KernelAndDeviceFunc::Run( const EagerKernelArgs& inputs, std::vector* outputs, CancellationManager* cancellation_manager, const absl::optional& remote_func_params) { - const std::vector devices = pflr_->device_mgr()->ListDevices(); - ScopedStepContainer step_container(0, [&devices](const string& name) { - for (Device* device : devices) { - device->resource_manager()->Cleanup(name).IgnoreError(); - } - }); - return this->Run(&step_container, inputs, outputs, cancellation_manager, - remote_func_params); + Status s = this->Run(&step_container_, inputs, outputs, cancellation_manager, + remote_func_params); + step_container_.CleanUp(); + return s; } namespace { diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index a622d52c962..395dcc98f78 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -175,7 +175,10 @@ class KernelAndDeviceOp final : public KernelAndDevice { host_cpu_device), rendez_(rendez), log_memory_(log_memory), - compile_with_xla_(compile_with_xla) {} + compile_with_xla_(compile_with_xla), + step_container_(0, [this](const string& name) { + device_->resource_manager()->Cleanup(name).IgnoreError(); + }) {} ~KernelAndDeviceOp() override {} @@ -212,6 +215,8 @@ class KernelAndDeviceOp final : public KernelAndDevice { checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; const bool log_memory_; const bool compile_with_xla_; + + ScopedStepContainer step_container_; }; // Represents a multi-device function. Functions can also be run using @@ -241,9 +246,17 @@ class KernelAndDeviceFunc final : public KernelAndDevice { std::move(input_resource_dtypes_and_shapes)), name_(name), rendezvous_creator_(std::move(rendezvous_creator)), - get_op_id_(std::move(get_op_id)) {} + get_op_id_(std::move(get_op_id)), + step_container_(0, [this](const string& name) { + // TODO(b/139809335): This does not properly clean up remote resources + const std::vector devices = + pflr_->device_mgr()->ListDevices(); + for (Device* device : devices) { + device->resource_manager()->Cleanup(name).IgnoreError(); + } + }) {} - virtual ~KernelAndDeviceFunc(); + ~KernelAndDeviceFunc() override; bool IsFunction() override { return true; }; @@ -295,6 +308,8 @@ class KernelAndDeviceFunc final : public KernelAndDevice { std::function rendezvous_creator_; std::function get_op_id_; + + ScopedStepContainer step_container_; }; } // namespace tensorflow diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index bf3b769f0c8..34ef6e694d3 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -33,18 +33,13 @@ namespace tensorflow { static std::atomic current_id_; ResourceHandle MakeResourceHandle( - OpKernelContext* ctx, const string& container, const string& name, + const string& container, const string& name, const DeviceBase& device, const TypeIndex& type_index, const std::vector& dtypes_and_shapes) { ResourceHandle result; - result.set_device(ctx->device()->attributes().name()); + result.set_device(device.name()); string actual_container; - if (!container.empty()) { - actual_container = container; - } else { - actual_container = ctx->resource_manager()->default_container(); - } - result.set_container(actual_container); + result.set_container(container); if (name == ResourceHandle::ANONYMOUS_NAME) { result.set_name(strings::StrCat("_AnonymousVar", current_id_.fetch_add(1))); } else { @@ -63,7 +58,7 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, TF_RETURN_IF_ERROR( context->allocate_output(output_index, TensorShape({}), &handle)); handle->scalar()() = - MakeResourceHandle(context, container, name, type_index); + MakeResourceHandle(container, name, *context->device(), type_index); return Status::OK(); } diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index 1fc1165c069..fa52e4881fc 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -93,26 +93,62 @@ class ScopedStepContainer { // prefix: optional string prefix to disambiguate step containers. ScopedStepContainer(const int64 step_id, std::function cleanup) - : name_(strings::StrCat("__per_step_", step_id)), + : container_(strings::StrCat("__per_step_", step_id)), step_id_(step_id), - cleanup_(cleanup) {} + cleanup_(cleanup), + dirty_(false) {} ScopedStepContainer(const int64 step_id, std::function cleanup, const string& prefix) - : name_(strings::StrCat("__", prefix, "_per_step_", step_id)), + : container_(strings::StrCat("__", prefix, "_per_step_", step_id)), step_id_(step_id), - cleanup_(cleanup) {} + cleanup_(cleanup), + dirty_(false) {} - ~ScopedStepContainer() { cleanup_(name_); } + ~ScopedStepContainer() { CleanUp(); } + + void CleanUp() { + mutex_lock ml(mu_); + if (dirty_) { + cleanup_(container_); + dirty_ = false; + } + } + + // Pass through functions for resource lookup and creation. We do this to + // ensure that we can appropriately set the dirty_ bit in the + // ScopedStepContainer if the name of the container is used to create + // resources. + + // Pass through to MakeResourceHandle with the container name + template + ResourceHandle MakeResourceHandle( + const string& name, const DeviceBase& device) TF_MUST_USE_RESULT; + // Pass through to ResourceMgr::Create with the container name + template + Status Create(ResourceMgr* rm, const string& name, + T* resource) TF_MUST_USE_RESULT; + // Pass through to ResourceMgr::Delete with the container name + template + Status Delete(ResourceMgr* rm, const string& name) TF_MUST_USE_RESULT; + // Pass through to ResourceMgr::Lookup with the container name + template + Status Lookup(ResourceMgr* rm, const string& name, + T** resource) const TF_MUST_USE_RESULT; + // Pass through to ResourceMgr::LookupOrCreate with the container name + template + Status LookupOrCreate(ResourceMgr* rm, const string& name, T** resource, + std::function creator) TF_MUST_USE_RESULT; - const string& name() const { return name_; } const int64 step_id() const { return step_id_; } private: - const string name_; + const string container_; const int64 step_id_; const std::function cleanup_; + mutex mu_; + mutable bool dirty_ GUARDED_BY(mu_); }; class ResourceMgr { @@ -255,26 +291,25 @@ class ResourceMgr { // Makes a resource handle with the specified type for a given container / // name. ResourceHandle MakeResourceHandle( - OpKernelContext* ctx, const string& container, const string& name, + const string& container, const string& name, const DeviceBase& device, const TypeIndex& type_index, - const std::vector& dtypes_and_shapes = {}); + const std::vector& dtypes_and_shapes = {}) + TF_MUST_USE_RESULT; template ResourceHandle MakeResourceHandle( OpKernelContext* ctx, const string& container, const string& name, const std::vector& dtypes_and_shapes = {}) { - return MakeResourceHandle(ctx, container, name, MakeTypeIndex(), - dtypes_and_shapes); + return MakeResourceHandle( + container.empty() ? ctx->resource_manager()->default_container() + : container, + name, *ctx->device(), MakeTypeIndex(), dtypes_and_shapes); } Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, const string& container, const string& name, const TypeIndex& type_index); -template -ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx, - const string& name); - // Returns a resource handle from a numbered op input. const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input); Status HandleFromInput(OpKernelContext* ctx, StringPiece input, @@ -660,12 +695,6 @@ Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name, return ctx->resource_manager()->Lookup(container, shared_name, resource); } -template -ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx, - const string& name) { - return MakeResourceHandle(ctx, ctx->step_container()->name(), name); -} - namespace internal { Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p); @@ -840,6 +869,43 @@ void ResourceHandlesOp::Compute(OpKernelContext* ctx) { } } +template +ResourceHandle ScopedStepContainer::MakeResourceHandle( + const string& name, const DeviceBase& device) { + mutex_lock ml(mu_); + dirty_ = true; + return tensorflow::MakeResourceHandle(container_, name, device, + MakeTypeIndex(), {}); +} + +template +Status ScopedStepContainer::Lookup(ResourceMgr* rm, const string& name, + T** resource) const { + return rm->Lookup(container_, name, resource); +} + +template +Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm, const string& name, + T** resource, + std::function creator) { + mutex_lock ml(mu_); + dirty_ = true; + return rm->LookupOrCreate(container_, name, resource, creator); +} + +template +Status ScopedStepContainer::Create(ResourceMgr* rm, const string& name, + T* resource) { + mutex_lock ml(mu_); + dirty_ = true; + return rm->Create(container_, name, resource); +} + +template +Status ScopedStepContainer::Delete(ResourceMgr* rm, const string& name) { + return rm->Delete(container_, name); +} + } // end namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index e093b144c85..84ecd79efdf 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -238,6 +238,7 @@ class StubDevice : public DeviceBase { } const DeviceAttributes& attributes() const override { return attr_; } + const string& name() const override { return attr_.name(); } private: DeviceAttributes attr_; diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 02931024dca..a2e7bd81b37 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -34,8 +34,8 @@ Status CreateHandle(OpKernelContext* ctx, T* resource, ResourceMgr* mgr = ctx->resource_manager(); TF_RETURN_IF_ERROR(mgr->Create(container_name, unique_name, resource)); - *handle = - MakeResourceHandle(ctx, container_name, unique_name, MakeTypeIndex()); + *handle = MakeResourceHandle(container_name, unique_name, *ctx->device(), + MakeTypeIndex()); return Status::OK(); } diff --git a/tensorflow/core/kernels/stack.cc b/tensorflow/core/kernels/stack.cc index af8f760d47f..f6d37edc896 100644 --- a/tensorflow/core/kernels/stack.cc +++ b/tensorflow/core/kernels/stack.cc @@ -145,7 +145,7 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) { if (step_container == nullptr) { return errors::Internal("No step container."); } - TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack)); + TF_RETURN_IF_ERROR(step_container->Lookup(rm, key, stack)); return Status::OK(); } } @@ -188,7 +188,7 @@ void StackOp::Compute(OpKernelContext* ctx) { OP_REQUIRES(ctx, step_container != nullptr, errors::Internal("No step container.")); Stack* stack = new Stack(elem_type_, stack_name, size); - OP_REQUIRES_OK(ctx, rm->Create(step_container->name(), key, stack)); + OP_REQUIRES_OK(ctx, step_container->Create(rm, key, stack)); if (IsRefType(ctx->expected_output_dtype(0))) { // Create the stack handle. AllocatorAttributes alloc_attr; @@ -204,7 +204,7 @@ void StackOp::Compute(OpKernelContext* ctx) { Tensor* handle; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); handle->flat()(0) = - MakePerStepResourceHandle(ctx, key); + ctx->step_container()->MakeResourceHandle(key, *ctx->device()); } } diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index bea97d1a1f1..e41b15016b6 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -46,7 +46,7 @@ Status AddToTensor(OpKernelContext* ctx, Tensor* sum, const Tensor* current, return errors::InvalidArgument( "tensor_array::AddToTensor type not supported: ", DataTypeString(DataTypeToEnum::value)); -}; +} #define TENSOR_ARRAY_WRITE_OR_ADD(Device, T) \ template <> \ @@ -74,7 +74,7 @@ Status TensorSetZero(OpKernelContext* ctx, Tensor* value) { return errors::InvalidArgument( "tensor_array::TensorSetZero type not supported: ", DataTypeString(DataTypeToEnum::value)); -}; +} #define TENSOR_ARRAY_SET_ZERO(Device, T) \ template <> \ @@ -347,7 +347,8 @@ class TensorArray : public ResourceBase { Tensor* handle() { return &handle_; } ResourceHandle resource_handle(OpKernelContext* ctx) { - return MakePerStepResourceHandle(ctx, key_); + return ctx->step_container()->MakeResourceHandle( + key_, *ctx->device()); } private: diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 52162e94650..62d03f9fb7f 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -79,8 +79,8 @@ Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) { TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle)); ResourceMgr* rm = ctx->resource_manager(); if (rm == nullptr) return errors::Internal("No resource manager."); - TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(), - container + ta_handle, tensor_array)); + TF_RETURN_IF_ERROR( + ctx->step_container()->Lookup(rm, container + ta_handle, tensor_array)); return Status::OK(); } else { return LookupResource(ctx, HandleFromInput(ctx, 0), tensor_array); @@ -209,8 +209,7 @@ class TensorArrayOp : public TensorArrayCreationOp { false /* multiple_writes_aggregate */, false /* is_grad */, -1 /* marked_size */, clear_after_read_); - TF_RETURN_IF_ERROR( - rm->Create(ctx->step_container()->name(), key, tensor_array)); + TF_RETURN_IF_ERROR(ctx->step_container()->Create(rm, key, tensor_array)); *output_tensor_array = tensor_array; @@ -306,9 +305,8 @@ class TensorArrayGradOp : public TensorArrayCreationOp { output_handle(1) = strings::StrCat(tensor_array_name, "@", source_); TensorArray* tensor_array; - TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(), - strings::StrCat(container, tensor_array_name), - &tensor_array)); + TF_RETURN_IF_ERROR(ctx->step_container()->Lookup( + rm, strings::StrCat(container, tensor_array_name), &tensor_array)); core::ScopedUnref unref(tensor_array); // Once gradients are being calculated, the forward TensorArray @@ -364,8 +362,8 @@ class TensorArrayGradOp : public TensorArrayCreationOp { return (*ret)->CopyShapesFrom(tensor_array, &shape_to_prepend); }; - Status s = rm->LookupOrCreate( - ctx->step_container()->name(), key, output_tensor_array, creator); + Status s = ctx->step_container()->LookupOrCreate( + rm, key, output_tensor_array, creator); (*output_tensor_array)->Unref(); return s; diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc index 3865bdbb848..b023a506d86 100644 --- a/tensorflow/core/kernels/variable_ops.cc +++ b/tensorflow/core/kernels/variable_ops.cc @@ -100,8 +100,8 @@ class TemporaryVariableOp : public OpKernel { s = context->allocate_temp(dtype_, shape_, &tmp_var->val); if (!s.ok()) tmp_var->Unref(); OP_REQUIRES_OK(context, s); - OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(), - var_name_, tmp_var)); + OP_REQUIRES_OK(context, + context->step_container()->Create(rm, var_name_, tmp_var)); context->set_output_ref(0, &tmp_var->mu, &tmp_var->val); if (context->track_allocations()) { context->record_persistent_memory_allocation( @@ -145,8 +145,9 @@ class DestroyTemporaryVariableOp : public OpKernel { context->set_output(0, tmpvar); ResourceMgr* rm = context->resource_manager(); OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager.")); - OP_REQUIRES_OK(context, rm->Delete( - context->step_container()->name(), var_name_)); + OP_REQUIRES_OK( + context, context->step_container()->Delete( + rm, var_name_)); if (context->track_allocations()) { context->record_persistent_memory_allocation( -static_cast(tmpvar.AllocatedBytes()));