Avoid allocating ScopedStepContainer for each Run
We avoid recreating a ScopedStepContainer by storing one for reuse in the KernelAndDeviceOp & KernelAndDeviceFunc classes. Further, we can avoid doing a resource manager lookup to perform a cleanup by adding a dirty flag that indicates whether the ScopedStepContainer was accessed. In addition, we simplify the signature of MakeResourceHandle by avoiding the need to pass in the entire OpKernelContext object. PiperOrigin-RevId: 281110991 Change-Id: I0a186583a1ff50b08bf68c18cfb99c912e05386d
This commit is contained in:
parent
87451d7147
commit
309a3c7964
@ -127,9 +127,9 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
|
||||
step_id, [&status, device](const string& name) {
|
||||
status = device->resource_manager()->Cleanup(name);
|
||||
});
|
||||
TF_RETURN_IF_ERROR(device->resource_manager()->Create(
|
||||
step_container->name(), XlaContext::kXlaContextResourceName,
|
||||
xla_context));
|
||||
TF_RETURN_IF_ERROR(step_container->Create(device->resource_manager(),
|
||||
XlaContext::kXlaContextResourceName,
|
||||
xla_context));
|
||||
|
||||
GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
|
||||
TF_RETURN_IF_ERROR(graph_compiler.Compile());
|
||||
|
@ -45,8 +45,8 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context";
|
||||
// per-step context is looked up in the resource manager. The
|
||||
// JIT will prepopulate the JITContext.
|
||||
XlaContext* context;
|
||||
TF_CHECK_OK(ctx->resource_manager()->Lookup(
|
||||
ctx->step_container()->name(), kXlaContextResourceName, &context));
|
||||
TF_CHECK_OK(ctx->step_container()->Lookup(ctx->resource_manager(),
|
||||
kXlaContextResourceName, &context));
|
||||
// The resource manager handed us a fresh reference to 'context', but retains
|
||||
// a reference itself so the context won't be freed. The resource manager will
|
||||
// outlive the JIT compilation.
|
||||
|
@ -35,7 +35,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
@ -207,25 +206,20 @@ Status KernelAndDeviceOp::Run(
|
||||
const EagerKernelArgs& inputs, std::vector<Tensor>* outputs,
|
||||
CancellationManager* cancellation_manager,
|
||||
const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
|
||||
ScopedStepContainer step_container(0, [this](const string& name) {
|
||||
device_->resource_manager()->Cleanup(name).IgnoreError();
|
||||
});
|
||||
return this->Run(&step_container, inputs, outputs, cancellation_manager,
|
||||
remote_func_params);
|
||||
Status s = this->Run(&step_container_, inputs, outputs, cancellation_manager,
|
||||
remote_func_params);
|
||||
step_container_.CleanUp();
|
||||
return s;
|
||||
}
|
||||
|
||||
Status KernelAndDeviceFunc::Run(
|
||||
const EagerKernelArgs& inputs, std::vector<Tensor>* outputs,
|
||||
CancellationManager* cancellation_manager,
|
||||
const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
|
||||
const std::vector<Device*> devices = pflr_->device_mgr()->ListDevices();
|
||||
ScopedStepContainer step_container(0, [&devices](const string& name) {
|
||||
for (Device* device : devices) {
|
||||
device->resource_manager()->Cleanup(name).IgnoreError();
|
||||
}
|
||||
});
|
||||
return this->Run(&step_container, inputs, outputs, cancellation_manager,
|
||||
remote_func_params);
|
||||
Status s = this->Run(&step_container_, inputs, outputs, cancellation_manager,
|
||||
remote_func_params);
|
||||
step_container_.CleanUp();
|
||||
return s;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -175,7 +175,10 @@ class KernelAndDeviceOp final : public KernelAndDevice {
|
||||
host_cpu_device),
|
||||
rendez_(rendez),
|
||||
log_memory_(log_memory),
|
||||
compile_with_xla_(compile_with_xla) {}
|
||||
compile_with_xla_(compile_with_xla),
|
||||
step_container_(0, [this](const string& name) {
|
||||
device_->resource_manager()->Cleanup(name).IgnoreError();
|
||||
}) {}
|
||||
|
||||
~KernelAndDeviceOp() override {}
|
||||
|
||||
@ -212,6 +215,8 @@ class KernelAndDeviceOp final : public KernelAndDevice {
|
||||
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
|
||||
const bool log_memory_;
|
||||
const bool compile_with_xla_;
|
||||
|
||||
ScopedStepContainer step_container_;
|
||||
};
|
||||
|
||||
// Represents a multi-device function. Functions can also be run using
|
||||
@ -241,9 +246,17 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
|
||||
std::move(input_resource_dtypes_and_shapes)),
|
||||
name_(name),
|
||||
rendezvous_creator_(std::move(rendezvous_creator)),
|
||||
get_op_id_(std::move(get_op_id)) {}
|
||||
get_op_id_(std::move(get_op_id)),
|
||||
step_container_(0, [this](const string& name) {
|
||||
// TODO(b/139809335): This does not properly clean up remote resources
|
||||
const std::vector<Device*> devices =
|
||||
pflr_->device_mgr()->ListDevices();
|
||||
for (Device* device : devices) {
|
||||
device->resource_manager()->Cleanup(name).IgnoreError();
|
||||
}
|
||||
}) {}
|
||||
|
||||
virtual ~KernelAndDeviceFunc();
|
||||
~KernelAndDeviceFunc() override;
|
||||
|
||||
bool IsFunction() override { return true; };
|
||||
|
||||
@ -295,6 +308,8 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
|
||||
|
||||
std::function<Rendezvous*(const int64)> rendezvous_creator_;
|
||||
std::function<int64()> get_op_id_;
|
||||
|
||||
ScopedStepContainer step_container_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -33,18 +33,13 @@ namespace tensorflow {
|
||||
static std::atomic<int64> current_id_;
|
||||
|
||||
ResourceHandle MakeResourceHandle(
|
||||
OpKernelContext* ctx, const string& container, const string& name,
|
||||
const string& container, const string& name, const DeviceBase& device,
|
||||
const TypeIndex& type_index,
|
||||
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes) {
|
||||
ResourceHandle result;
|
||||
result.set_device(ctx->device()->attributes().name());
|
||||
result.set_device(device.name());
|
||||
string actual_container;
|
||||
if (!container.empty()) {
|
||||
actual_container = container;
|
||||
} else {
|
||||
actual_container = ctx->resource_manager()->default_container();
|
||||
}
|
||||
result.set_container(actual_container);
|
||||
result.set_container(container);
|
||||
if (name == ResourceHandle::ANONYMOUS_NAME) {
|
||||
result.set_name(strings::StrCat("_AnonymousVar", current_id_.fetch_add(1)));
|
||||
} else {
|
||||
@ -63,7 +58,7 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->allocate_output(output_index, TensorShape({}), &handle));
|
||||
handle->scalar<ResourceHandle>()() =
|
||||
MakeResourceHandle(context, container, name, type_index);
|
||||
MakeResourceHandle(container, name, *context->device(), type_index);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -93,26 +93,62 @@ class ScopedStepContainer {
|
||||
// prefix: optional string prefix to disambiguate step containers.
|
||||
ScopedStepContainer(const int64 step_id,
|
||||
std::function<void(const string&)> cleanup)
|
||||
: name_(strings::StrCat("__per_step_", step_id)),
|
||||
: container_(strings::StrCat("__per_step_", step_id)),
|
||||
step_id_(step_id),
|
||||
cleanup_(cleanup) {}
|
||||
cleanup_(cleanup),
|
||||
dirty_(false) {}
|
||||
|
||||
ScopedStepContainer(const int64 step_id,
|
||||
std::function<void(const string&)> cleanup,
|
||||
const string& prefix)
|
||||
: name_(strings::StrCat("__", prefix, "_per_step_", step_id)),
|
||||
: container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
|
||||
step_id_(step_id),
|
||||
cleanup_(cleanup) {}
|
||||
cleanup_(cleanup),
|
||||
dirty_(false) {}
|
||||
|
||||
~ScopedStepContainer() { cleanup_(name_); }
|
||||
~ScopedStepContainer() { CleanUp(); }
|
||||
|
||||
void CleanUp() {
|
||||
mutex_lock ml(mu_);
|
||||
if (dirty_) {
|
||||
cleanup_(container_);
|
||||
dirty_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Pass through functions for resource lookup and creation. We do this to
|
||||
// ensure that we can appropriately set the dirty_ bit in the
|
||||
// ScopedStepContainer if the name of the container is used to create
|
||||
// resources.
|
||||
|
||||
// Pass through to MakeResourceHandle with the container name
|
||||
template <typename T>
|
||||
ResourceHandle MakeResourceHandle(
|
||||
const string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
|
||||
// Pass through to ResourceMgr::Create with the container name
|
||||
template <typename T>
|
||||
Status Create(ResourceMgr* rm, const string& name,
|
||||
T* resource) TF_MUST_USE_RESULT;
|
||||
// Pass through to ResourceMgr::Delete with the container name
|
||||
template <typename T>
|
||||
Status Delete(ResourceMgr* rm, const string& name) TF_MUST_USE_RESULT;
|
||||
// Pass through to ResourceMgr::Lookup with the container name
|
||||
template <typename T>
|
||||
Status Lookup(ResourceMgr* rm, const string& name,
|
||||
T** resource) const TF_MUST_USE_RESULT;
|
||||
// Pass through to ResourceMgr::LookupOrCreate with the container name
|
||||
template <typename T>
|
||||
Status LookupOrCreate(ResourceMgr* rm, const string& name, T** resource,
|
||||
std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
|
||||
|
||||
const string& name() const { return name_; }
|
||||
const int64 step_id() const { return step_id_; }
|
||||
|
||||
private:
|
||||
const string name_;
|
||||
const string container_;
|
||||
const int64 step_id_;
|
||||
const std::function<void(const string&)> cleanup_;
|
||||
mutex mu_;
|
||||
mutable bool dirty_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
class ResourceMgr {
|
||||
@ -255,26 +291,25 @@ class ResourceMgr {
|
||||
// Makes a resource handle with the specified type for a given container /
|
||||
// name.
|
||||
ResourceHandle MakeResourceHandle(
|
||||
OpKernelContext* ctx, const string& container, const string& name,
|
||||
const string& container, const string& name, const DeviceBase& device,
|
||||
const TypeIndex& type_index,
|
||||
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {});
|
||||
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {})
|
||||
TF_MUST_USE_RESULT;
|
||||
|
||||
template <typename T>
|
||||
ResourceHandle MakeResourceHandle(
|
||||
OpKernelContext* ctx, const string& container, const string& name,
|
||||
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
|
||||
return MakeResourceHandle(ctx, container, name, MakeTypeIndex<T>(),
|
||||
dtypes_and_shapes);
|
||||
return MakeResourceHandle(
|
||||
container.empty() ? ctx->resource_manager()->default_container()
|
||||
: container,
|
||||
name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
|
||||
}
|
||||
|
||||
Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
|
||||
const string& container, const string& name,
|
||||
const TypeIndex& type_index);
|
||||
|
||||
template <typename T>
|
||||
ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
|
||||
const string& name);
|
||||
|
||||
// Returns a resource handle from a numbered op input.
|
||||
const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
|
||||
Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
|
||||
@ -660,12 +695,6 @@ Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
|
||||
return ctx->resource_manager()->Lookup(container, shared_name, resource);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
|
||||
const string& name) {
|
||||
return MakeResourceHandle<T>(ctx, ctx->step_container()->name(), name);
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
|
||||
Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
|
||||
@ -840,6 +869,43 @@ void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ResourceHandle ScopedStepContainer::MakeResourceHandle(
|
||||
const string& name, const DeviceBase& device) {
|
||||
mutex_lock ml(mu_);
|
||||
dirty_ = true;
|
||||
return tensorflow::MakeResourceHandle(container_, name, device,
|
||||
MakeTypeIndex<T>(), {});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ScopedStepContainer::Lookup(ResourceMgr* rm, const string& name,
|
||||
T** resource) const {
|
||||
return rm->Lookup<T>(container_, name, resource);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm, const string& name,
|
||||
T** resource,
|
||||
std::function<Status(T**)> creator) {
|
||||
mutex_lock ml(mu_);
|
||||
dirty_ = true;
|
||||
return rm->LookupOrCreate<T>(container_, name, resource, creator);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ScopedStepContainer::Create(ResourceMgr* rm, const string& name,
|
||||
T* resource) {
|
||||
mutex_lock ml(mu_);
|
||||
dirty_ = true;
|
||||
return rm->Create<T>(container_, name, resource);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ScopedStepContainer::Delete(ResourceMgr* rm, const string& name) {
|
||||
return rm->Delete<T>(container_, name);
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
|
||||
|
@ -238,6 +238,7 @@ class StubDevice : public DeviceBase {
|
||||
}
|
||||
|
||||
const DeviceAttributes& attributes() const override { return attr_; }
|
||||
const string& name() const override { return attr_.name(); }
|
||||
|
||||
private:
|
||||
DeviceAttributes attr_;
|
||||
|
@ -34,8 +34,8 @@ Status CreateHandle(OpKernelContext* ctx, T* resource,
|
||||
ResourceMgr* mgr = ctx->resource_manager();
|
||||
TF_RETURN_IF_ERROR(mgr->Create<T>(container_name, unique_name, resource));
|
||||
|
||||
*handle =
|
||||
MakeResourceHandle(ctx, container_name, unique_name, MakeTypeIndex<T>());
|
||||
*handle = MakeResourceHandle(container_name, unique_name, *ctx->device(),
|
||||
MakeTypeIndex<T>());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -145,7 +145,7 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) {
|
||||
if (step_container == nullptr) {
|
||||
return errors::Internal("No step container.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
|
||||
TF_RETURN_IF_ERROR(step_container->Lookup(rm, key, stack));
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
@ -188,7 +188,7 @@ void StackOp::Compute(OpKernelContext* ctx) {
|
||||
OP_REQUIRES(ctx, step_container != nullptr,
|
||||
errors::Internal("No step container."));
|
||||
Stack* stack = new Stack(elem_type_, stack_name, size);
|
||||
OP_REQUIRES_OK(ctx, rm->Create(step_container->name(), key, stack));
|
||||
OP_REQUIRES_OK(ctx, step_container->Create(rm, key, stack));
|
||||
if (IsRefType(ctx->expected_output_dtype(0))) {
|
||||
// Create the stack handle.
|
||||
AllocatorAttributes alloc_attr;
|
||||
@ -204,7 +204,7 @@ void StackOp::Compute(OpKernelContext* ctx) {
|
||||
Tensor* handle;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
|
||||
handle->flat<ResourceHandle>()(0) =
|
||||
MakePerStepResourceHandle<Stack>(ctx, key);
|
||||
ctx->step_container()->MakeResourceHandle<Stack>(key, *ctx->device());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -46,7 +46,7 @@ Status AddToTensor(OpKernelContext* ctx, Tensor* sum, const Tensor* current,
|
||||
return errors::InvalidArgument(
|
||||
"tensor_array::AddToTensor type not supported: ",
|
||||
DataTypeString(DataTypeToEnum<T>::value));
|
||||
};
|
||||
}
|
||||
|
||||
#define TENSOR_ARRAY_WRITE_OR_ADD(Device, T) \
|
||||
template <> \
|
||||
@ -74,7 +74,7 @@ Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
|
||||
return errors::InvalidArgument(
|
||||
"tensor_array::TensorSetZero type not supported: ",
|
||||
DataTypeString(DataTypeToEnum<T>::value));
|
||||
};
|
||||
}
|
||||
|
||||
#define TENSOR_ARRAY_SET_ZERO(Device, T) \
|
||||
template <> \
|
||||
@ -347,7 +347,8 @@ class TensorArray : public ResourceBase {
|
||||
Tensor* handle() { return &handle_; }
|
||||
|
||||
ResourceHandle resource_handle(OpKernelContext* ctx) {
|
||||
return MakePerStepResourceHandle<TensorArray>(ctx, key_);
|
||||
return ctx->step_container()->MakeResourceHandle<TensorArray>(
|
||||
key_, *ctx->device());
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -79,8 +79,8 @@ Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) {
|
||||
TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle));
|
||||
ResourceMgr* rm = ctx->resource_manager();
|
||||
if (rm == nullptr) return errors::Internal("No resource manager.");
|
||||
TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
|
||||
container + ta_handle, tensor_array));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->step_container()->Lookup(rm, container + ta_handle, tensor_array));
|
||||
return Status::OK();
|
||||
} else {
|
||||
return LookupResource(ctx, HandleFromInput(ctx, 0), tensor_array);
|
||||
@ -209,8 +209,7 @@ class TensorArrayOp : public TensorArrayCreationOp {
|
||||
false /* multiple_writes_aggregate */, false /* is_grad */,
|
||||
-1 /* marked_size */, clear_after_read_);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
rm->Create(ctx->step_container()->name(), key, tensor_array));
|
||||
TF_RETURN_IF_ERROR(ctx->step_container()->Create(rm, key, tensor_array));
|
||||
|
||||
*output_tensor_array = tensor_array;
|
||||
|
||||
@ -306,9 +305,8 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
|
||||
output_handle(1) = strings::StrCat(tensor_array_name, "@", source_);
|
||||
|
||||
TensorArray* tensor_array;
|
||||
TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
|
||||
strings::StrCat(container, tensor_array_name),
|
||||
&tensor_array));
|
||||
TF_RETURN_IF_ERROR(ctx->step_container()->Lookup(
|
||||
rm, strings::StrCat(container, tensor_array_name), &tensor_array));
|
||||
core::ScopedUnref unref(tensor_array);
|
||||
|
||||
// Once gradients are being calculated, the forward TensorArray
|
||||
@ -364,8 +362,8 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
|
||||
return (*ret)->CopyShapesFrom(tensor_array, &shape_to_prepend);
|
||||
};
|
||||
|
||||
Status s = rm->LookupOrCreate<TensorArray>(
|
||||
ctx->step_container()->name(), key, output_tensor_array, creator);
|
||||
Status s = ctx->step_container()->LookupOrCreate<TensorArray>(
|
||||
rm, key, output_tensor_array, creator);
|
||||
(*output_tensor_array)->Unref();
|
||||
|
||||
return s;
|
||||
|
@ -100,8 +100,8 @@ class TemporaryVariableOp : public OpKernel {
|
||||
s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
|
||||
if (!s.ok()) tmp_var->Unref();
|
||||
OP_REQUIRES_OK(context, s);
|
||||
OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
|
||||
var_name_, tmp_var));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->step_container()->Create(rm, var_name_, tmp_var));
|
||||
context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
|
||||
if (context->track_allocations()) {
|
||||
context->record_persistent_memory_allocation(
|
||||
@ -145,8 +145,9 @@ class DestroyTemporaryVariableOp : public OpKernel {
|
||||
context->set_output(0, tmpvar);
|
||||
ResourceMgr* rm = context->resource_manager();
|
||||
OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
|
||||
OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>(
|
||||
context->step_container()->name(), var_name_));
|
||||
OP_REQUIRES_OK(
|
||||
context, context->step_container()->Delete<TemporaryVariableOp::TmpVar>(
|
||||
rm, var_name_));
|
||||
if (context->track_allocations()) {
|
||||
context->record_persistent_memory_allocation(
|
||||
-static_cast<int64>(tmpvar.AllocatedBytes()));
|
||||
|
Loading…
Reference in New Issue
Block a user