Avoid allocating ScopedStepContainer for each Run

We avoid recreating a ScopedStepContainer by storing one for reuse in
the KernelAndDeviceOp & KernelAndDeviceFunc classes. Further, we can
avoid doing a resource manager lookup to perform a cleanup by adding a
dirty flag that indicates whether the ScopedStepContainer was accessed.

In addition, we simplify the signature of MakeResourceHandle by avoiding
the need to pass in the entire OpKernelContext object.

PiperOrigin-RevId: 281110991
Change-Id: I0a186583a1ff50b08bf68c18cfb99c912e05386d
This commit is contained in:
Gaurav Jain 2019-11-18 11:14:44 -08:00 committed by TensorFlower Gardener
parent 87451d7147
commit 309a3c7964
12 changed files with 144 additions and 73 deletions

View File

@ -127,9 +127,9 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
step_id, [&status, device](const string& name) {
status = device->resource_manager()->Cleanup(name);
});
TF_RETURN_IF_ERROR(device->resource_manager()->Create(
step_container->name(), XlaContext::kXlaContextResourceName,
xla_context));
TF_RETURN_IF_ERROR(step_container->Create(device->resource_manager(),
XlaContext::kXlaContextResourceName,
xla_context));
GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
TF_RETURN_IF_ERROR(graph_compiler.Compile());

View File

@ -45,8 +45,8 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context";
// per-step context is looked up in the resource manager. The
// JIT will prepopulate the JITContext.
XlaContext* context;
TF_CHECK_OK(ctx->resource_manager()->Lookup(
ctx->step_container()->name(), kXlaContextResourceName, &context));
TF_CHECK_OK(ctx->step_container()->Lookup(ctx->resource_manager(),
kXlaContextResourceName, &context));
// The resource manager handed us a fresh reference to 'context', but retains
// a reference itself so the context won't be freed. The resource manager will
// outlive the JIT compilation.

View File

@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/version.h"
@ -207,25 +206,20 @@ Status KernelAndDeviceOp::Run(
const EagerKernelArgs& inputs, std::vector<Tensor>* outputs,
CancellationManager* cancellation_manager,
const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
ScopedStepContainer step_container(0, [this](const string& name) {
device_->resource_manager()->Cleanup(name).IgnoreError();
});
return this->Run(&step_container, inputs, outputs, cancellation_manager,
remote_func_params);
Status s = this->Run(&step_container_, inputs, outputs, cancellation_manager,
remote_func_params);
step_container_.CleanUp();
return s;
}
Status KernelAndDeviceFunc::Run(
const EagerKernelArgs& inputs, std::vector<Tensor>* outputs,
CancellationManager* cancellation_manager,
const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
const std::vector<Device*> devices = pflr_->device_mgr()->ListDevices();
ScopedStepContainer step_container(0, [&devices](const string& name) {
for (Device* device : devices) {
device->resource_manager()->Cleanup(name).IgnoreError();
}
});
return this->Run(&step_container, inputs, outputs, cancellation_manager,
remote_func_params);
Status s = this->Run(&step_container_, inputs, outputs, cancellation_manager,
remote_func_params);
step_container_.CleanUp();
return s;
}
namespace {

View File

@ -175,7 +175,10 @@ class KernelAndDeviceOp final : public KernelAndDevice {
host_cpu_device),
rendez_(rendez),
log_memory_(log_memory),
compile_with_xla_(compile_with_xla) {}
compile_with_xla_(compile_with_xla),
step_container_(0, [this](const string& name) {
device_->resource_manager()->Cleanup(name).IgnoreError();
}) {}
~KernelAndDeviceOp() override {}
@ -212,6 +215,8 @@ class KernelAndDeviceOp final : public KernelAndDevice {
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
const bool log_memory_;
const bool compile_with_xla_;
ScopedStepContainer step_container_;
};
// Represents a multi-device function. Functions can also be run using
@ -241,9 +246,17 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
std::move(input_resource_dtypes_and_shapes)),
name_(name),
rendezvous_creator_(std::move(rendezvous_creator)),
get_op_id_(std::move(get_op_id)) {}
get_op_id_(std::move(get_op_id)),
step_container_(0, [this](const string& name) {
// TODO(b/139809335): This does not properly clean up remote resources
const std::vector<Device*> devices =
pflr_->device_mgr()->ListDevices();
for (Device* device : devices) {
device->resource_manager()->Cleanup(name).IgnoreError();
}
}) {}
virtual ~KernelAndDeviceFunc();
~KernelAndDeviceFunc() override;
bool IsFunction() override { return true; };
@ -295,6 +308,8 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
std::function<Rendezvous*(const int64)> rendezvous_creator_;
std::function<int64()> get_op_id_;
ScopedStepContainer step_container_;
};
} // namespace tensorflow

View File

@ -33,18 +33,13 @@ namespace tensorflow {
static std::atomic<int64> current_id_;
ResourceHandle MakeResourceHandle(
OpKernelContext* ctx, const string& container, const string& name,
const string& container, const string& name, const DeviceBase& device,
const TypeIndex& type_index,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes) {
ResourceHandle result;
result.set_device(ctx->device()->attributes().name());
result.set_device(device.name());
string actual_container;
if (!container.empty()) {
actual_container = container;
} else {
actual_container = ctx->resource_manager()->default_container();
}
result.set_container(actual_container);
result.set_container(container);
if (name == ResourceHandle::ANONYMOUS_NAME) {
result.set_name(strings::StrCat("_AnonymousVar", current_id_.fetch_add(1)));
} else {
@ -63,7 +58,7 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
TF_RETURN_IF_ERROR(
context->allocate_output(output_index, TensorShape({}), &handle));
handle->scalar<ResourceHandle>()() =
MakeResourceHandle(context, container, name, type_index);
MakeResourceHandle(container, name, *context->device(), type_index);
return Status::OK();
}

View File

@ -93,26 +93,62 @@ class ScopedStepContainer {
// prefix: optional string prefix to disambiguate step containers.
ScopedStepContainer(const int64 step_id,
std::function<void(const string&)> cleanup)
: name_(strings::StrCat("__per_step_", step_id)),
: container_(strings::StrCat("__per_step_", step_id)),
step_id_(step_id),
cleanup_(cleanup) {}
cleanup_(cleanup),
dirty_(false) {}
ScopedStepContainer(const int64 step_id,
std::function<void(const string&)> cleanup,
const string& prefix)
: name_(strings::StrCat("__", prefix, "_per_step_", step_id)),
: container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
step_id_(step_id),
cleanup_(cleanup) {}
cleanup_(cleanup),
dirty_(false) {}
~ScopedStepContainer() { cleanup_(name_); }
~ScopedStepContainer() { CleanUp(); }
void CleanUp() {
mutex_lock ml(mu_);
if (dirty_) {
cleanup_(container_);
dirty_ = false;
}
}
// Pass through functions for resource lookup and creation. We do this to
// ensure that we can appropriately set the dirty_ bit in the
// ScopedStepContainer if the name of the container is used to create
// resources.
// Pass through to MakeResourceHandle with the container name
template <typename T>
ResourceHandle MakeResourceHandle(
const string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
// Pass through to ResourceMgr::Create with the container name
template <typename T>
Status Create(ResourceMgr* rm, const string& name,
T* resource) TF_MUST_USE_RESULT;
// Pass through to ResourceMgr::Delete with the container name
template <typename T>
Status Delete(ResourceMgr* rm, const string& name) TF_MUST_USE_RESULT;
// Pass through to ResourceMgr::Lookup with the container name
template <typename T>
Status Lookup(ResourceMgr* rm, const string& name,
T** resource) const TF_MUST_USE_RESULT;
// Pass through to ResourceMgr::LookupOrCreate with the container name
template <typename T>
Status LookupOrCreate(ResourceMgr* rm, const string& name, T** resource,
std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
const string& name() const { return name_; }
const int64 step_id() const { return step_id_; }
private:
const string name_;
const string container_;
const int64 step_id_;
const std::function<void(const string&)> cleanup_;
mutex mu_;
mutable bool dirty_ GUARDED_BY(mu_);
};
class ResourceMgr {
@ -255,26 +291,25 @@ class ResourceMgr {
// Makes a resource handle with the specified type for a given container /
// name.
ResourceHandle MakeResourceHandle(
OpKernelContext* ctx, const string& container, const string& name,
const string& container, const string& name, const DeviceBase& device,
const TypeIndex& type_index,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {});
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {})
TF_MUST_USE_RESULT;
template <typename T>
ResourceHandle MakeResourceHandle(
OpKernelContext* ctx, const string& container, const string& name,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
return MakeResourceHandle(ctx, container, name, MakeTypeIndex<T>(),
dtypes_and_shapes);
return MakeResourceHandle(
container.empty() ? ctx->resource_manager()->default_container()
: container,
name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
}
Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
const string& container, const string& name,
const TypeIndex& type_index);
template <typename T>
ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
const string& name);
// Returns a resource handle from a numbered op input.
const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
@ -660,12 +695,6 @@ Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
return ctx->resource_manager()->Lookup(container, shared_name, resource);
}
template <typename T>
ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
const string& name) {
return MakeResourceHandle<T>(ctx, ctx->step_container()->name(), name);
}
namespace internal {
Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
@ -840,6 +869,43 @@ void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
}
}
// Builds a ResourceHandle whose container is this step container's name.
// Because handing out the container name lets callers create resources in
// it, the dirty_ bit is set (under mu_) so that CleanUp() will actually run
// the cleanup callback for this step.
template <typename T>
ResourceHandle ScopedStepContainer::MakeResourceHandle(
const string& name, const DeviceBase& device) {
mutex_lock ml(mu_);
dirty_ = true;
// Delegates to the free function; passes an empty dtypes_and_shapes list.
return tensorflow::MakeResourceHandle(container_, name, device,
MakeTypeIndex<T>(), {});
}
// Pass-through to ResourceMgr::Lookup scoped to this step's container.
// Does not take mu_ or set dirty_: a lookup by itself creates nothing,
// so there is nothing new for the cleanup callback to remove.
template <typename T>
Status ScopedStepContainer::Lookup(ResourceMgr* rm, const string& name,
T** resource) const {
return rm->Lookup<T>(container_, name, resource);
}
// Pass-through to ResourceMgr::LookupOrCreate scoped to this step's
// container. Marks the container dirty (under mu_) unconditionally, since
// the call may create a resource that CleanUp() must later remove.
template <typename T>
Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm, const string& name,
T** resource,
std::function<Status(T**)> creator) {
mutex_lock ml(mu_);
dirty_ = true;
return rm->LookupOrCreate<T>(container_, name, resource, creator);
}
// Pass-through to ResourceMgr::Create scoped to this step's container.
// Sets dirty_ (under mu_) before creating: the new resource must be
// removed by the cleanup callback when CleanUp() runs.
// NOTE(review): dirty_ is set even if the Create call fails — the extra
// cleanup is then a harmless no-op rather than a leak.
template <typename T>
Status ScopedStepContainer::Create(ResourceMgr* rm, const string& name,
T* resource) {
mutex_lock ml(mu_);
dirty_ = true;
return rm->Create<T>(container_, name, resource);
}
// Pass-through to ResourceMgr::Delete scoped to this step's container.
// Removing a resource never adds anything needing cleanup, so dirty_ is
// left untouched and no lock is taken.
template <typename T>
Status ScopedStepContainer::Delete(ResourceMgr* rm, const string& name) {
return rm->Delete<T>(container_, name);
}
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_

View File

@ -238,6 +238,7 @@ class StubDevice : public DeviceBase {
}
const DeviceAttributes& attributes() const override { return attr_; }
const string& name() const override { return attr_.name(); }
private:
DeviceAttributes attr_;

View File

@ -34,8 +34,8 @@ Status CreateHandle(OpKernelContext* ctx, T* resource,
ResourceMgr* mgr = ctx->resource_manager();
TF_RETURN_IF_ERROR(mgr->Create<T>(container_name, unique_name, resource));
*handle =
MakeResourceHandle(ctx, container_name, unique_name, MakeTypeIndex<T>());
*handle = MakeResourceHandle(container_name, unique_name, *ctx->device(),
MakeTypeIndex<T>());
return Status::OK();
}

View File

@ -145,7 +145,7 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) {
if (step_container == nullptr) {
return errors::Internal("No step container.");
}
TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
TF_RETURN_IF_ERROR(step_container->Lookup(rm, key, stack));
return Status::OK();
}
}
@ -188,7 +188,7 @@ void StackOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES(ctx, step_container != nullptr,
errors::Internal("No step container."));
Stack* stack = new Stack(elem_type_, stack_name, size);
OP_REQUIRES_OK(ctx, rm->Create(step_container->name(), key, stack));
OP_REQUIRES_OK(ctx, step_container->Create(rm, key, stack));
if (IsRefType(ctx->expected_output_dtype(0))) {
// Create the stack handle.
AllocatorAttributes alloc_attr;
@ -204,7 +204,7 @@ void StackOp::Compute(OpKernelContext* ctx) {
Tensor* handle;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
handle->flat<ResourceHandle>()(0) =
MakePerStepResourceHandle<Stack>(ctx, key);
ctx->step_container()->MakeResourceHandle<Stack>(key, *ctx->device());
}
}

View File

@ -46,7 +46,7 @@ Status AddToTensor(OpKernelContext* ctx, Tensor* sum, const Tensor* current,
return errors::InvalidArgument(
"tensor_array::AddToTensor type not supported: ",
DataTypeString(DataTypeToEnum<T>::value));
};
}
#define TENSOR_ARRAY_WRITE_OR_ADD(Device, T) \
template <> \
@ -74,7 +74,7 @@ Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
return errors::InvalidArgument(
"tensor_array::TensorSetZero type not supported: ",
DataTypeString(DataTypeToEnum<T>::value));
};
}
#define TENSOR_ARRAY_SET_ZERO(Device, T) \
template <> \
@ -347,7 +347,8 @@ class TensorArray : public ResourceBase {
Tensor* handle() { return &handle_; }
ResourceHandle resource_handle(OpKernelContext* ctx) {
return MakePerStepResourceHandle<TensorArray>(ctx, key_);
return ctx->step_container()->MakeResourceHandle<TensorArray>(
key_, *ctx->device());
}
private:

View File

@ -79,8 +79,8 @@ Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) {
TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle));
ResourceMgr* rm = ctx->resource_manager();
if (rm == nullptr) return errors::Internal("No resource manager.");
TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
container + ta_handle, tensor_array));
TF_RETURN_IF_ERROR(
ctx->step_container()->Lookup(rm, container + ta_handle, tensor_array));
return Status::OK();
} else {
return LookupResource(ctx, HandleFromInput(ctx, 0), tensor_array);
@ -209,8 +209,7 @@ class TensorArrayOp : public TensorArrayCreationOp {
false /* multiple_writes_aggregate */, false /* is_grad */,
-1 /* marked_size */, clear_after_read_);
TF_RETURN_IF_ERROR(
rm->Create(ctx->step_container()->name(), key, tensor_array));
TF_RETURN_IF_ERROR(ctx->step_container()->Create(rm, key, tensor_array));
*output_tensor_array = tensor_array;
@ -306,9 +305,8 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
output_handle(1) = strings::StrCat(tensor_array_name, "@", source_);
TensorArray* tensor_array;
TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
strings::StrCat(container, tensor_array_name),
&tensor_array));
TF_RETURN_IF_ERROR(ctx->step_container()->Lookup(
rm, strings::StrCat(container, tensor_array_name), &tensor_array));
core::ScopedUnref unref(tensor_array);
// Once gradients are being calculated, the forward TensorArray
@ -364,8 +362,8 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
return (*ret)->CopyShapesFrom(tensor_array, &shape_to_prepend);
};
Status s = rm->LookupOrCreate<TensorArray>(
ctx->step_container()->name(), key, output_tensor_array, creator);
Status s = ctx->step_container()->LookupOrCreate<TensorArray>(
rm, key, output_tensor_array, creator);
(*output_tensor_array)->Unref();
return s;

View File

@ -100,8 +100,8 @@ class TemporaryVariableOp : public OpKernel {
s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
if (!s.ok()) tmp_var->Unref();
OP_REQUIRES_OK(context, s);
OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
var_name_, tmp_var));
OP_REQUIRES_OK(context,
context->step_container()->Create(rm, var_name_, tmp_var));
context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
if (context->track_allocations()) {
context->record_persistent_memory_allocation(
@ -145,8 +145,9 @@ class DestroyTemporaryVariableOp : public OpKernel {
context->set_output(0, tmpvar);
ResourceMgr* rm = context->resource_manager();
OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>(
context->step_container()->name(), var_name_));
OP_REQUIRES_OK(
context, context->step_container()->Delete<TemporaryVariableOp::TmpVar>(
rm, var_name_));
if (context->track_allocations()) {
context->record_persistent_memory_allocation(
-static_cast<int64>(tmpvar.AllocatedBytes()));