Fix MultiDeviceIterator memory leak in eager mode.
In Graph mode, we rely on MultiDeviceIteratorHandleOp destruction to decrement the ref count for the resource. Since we don't destroy kernels in Eager mode, we explicitly added a destroy_resource_op to mitigate this. The problem is that this isn't enough: the ResourceMgr::LookupOrCreate method ends up increasing the ref count of the resource by 2, and in graph mode we were effectively doing two Unrefs in the destructor. So even with the destroy resource op, the ref count remained 1 and never went down to zero. The fix here is to handle the Eager mode case separately, similar to what we've done with AnonymousIteratorHandleOp. Instead of creating a whole new kernel, we re-use the existing kernel and use a special shared_name argument to identify when to switch the behavior. Now, in Eager mode, after running the HandleOp kernel, the ref count of the resource is 1. PiperOrigin-RevId: 231333966
This commit is contained in:
parent
601921925d
commit
ad65038c90
@ -359,6 +359,9 @@ class MultiDeviceIterator : public ResourceBase {
|
||||
std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
// Used to generate unique names for anonymous multi device iterators.
|
||||
static std::atomic<int64> current_id_;
|
||||
|
||||
// Just creates a MultiDeviceIterator and returns it.
|
||||
class MultiDeviceIteratorHandleOp : public OpKernel {
|
||||
public:
|
||||
@ -388,6 +391,8 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
|
||||
string unique_name = cinfo_.name();
|
||||
string container_name = cinfo_.container();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (resource_ == nullptr) {
|
||||
@ -402,31 +407,49 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
|
||||
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
|
||||
|
||||
MultiDeviceIterator* resource;
|
||||
OP_REQUIRES_OK(context,
|
||||
mgr->LookupOrCreate<MultiDeviceIterator>(
|
||||
cinfo_.container(), cinfo_.name(), &resource,
|
||||
[this, lib, &flib_def, &pflr,
|
||||
&function_handle_cache](MultiDeviceIterator** ret)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*ret = new MultiDeviceIterator(
|
||||
output_types_, output_shapes_, devices_,
|
||||
std::move(flib_def), std::move(pflr), lib,
|
||||
std::move(function_handle_cache));
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
Status s = VerifyResource(resource);
|
||||
if (TF_PREDICT_FALSE(!s.ok())) {
|
||||
resource->Unref();
|
||||
context->SetStatus(s);
|
||||
return;
|
||||
if (name_ == ResourceHandle::ANONYMOUS_NAME) {
|
||||
unique_name = strings::StrCat("_AnonymousMultiDeviceIterator",
|
||||
current_id_.fetch_add(1));
|
||||
container_name = "AnonymousMultiDeviceIterator";
|
||||
resource = new MultiDeviceIterator(
|
||||
output_types_, output_shapes_, devices_, std::move(flib_def),
|
||||
std::move(pflr), lib, std::move(function_handle_cache));
|
||||
OP_REQUIRES_OK(context, mgr->Create<MultiDeviceIterator>(
|
||||
container_name, unique_name, resource));
|
||||
Status s = VerifyResource(resource);
|
||||
if (TF_PREDICT_FALSE(!s.ok())) {
|
||||
resource->Unref();
|
||||
context->SetStatus(s);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
unique_name = cinfo_.name();
|
||||
container_name = cinfo_.container();
|
||||
OP_REQUIRES_OK(context,
|
||||
mgr->LookupOrCreate<MultiDeviceIterator>(
|
||||
container_name, unique_name, &resource,
|
||||
[this, lib, &flib_def, &pflr,
|
||||
&function_handle_cache](MultiDeviceIterator** ret)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*ret = new MultiDeviceIterator(
|
||||
output_types_, output_shapes_, devices_,
|
||||
std::move(flib_def), std::move(pflr),
|
||||
lib, std::move(function_handle_cache));
|
||||
return Status::OK();
|
||||
}));
|
||||
Status s = VerifyResource(resource);
|
||||
if (TF_PREDICT_FALSE(!s.ok())) {
|
||||
resource->Unref();
|
||||
context->SetStatus(s);
|
||||
return;
|
||||
}
|
||||
resource_ = resource;
|
||||
}
|
||||
|
||||
resource_ = resource;
|
||||
}
|
||||
}
|
||||
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
|
||||
context, 0, cinfo_.container(), cinfo_.name(),
|
||||
context, 0, container_name, unique_name,
|
||||
MakeTypeIndex<MultiDeviceIterator>()));
|
||||
}
|
||||
|
||||
|
||||
@ -18,7 +18,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
from absl.testing import parameterized
|
||||
import six
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.data.experimental.ops import optimization
|
||||
@ -34,10 +36,40 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# memory_profiler might not be available in the OSS version of TensorFlow.
|
||||
try:
|
||||
import memory_profiler # pylint:disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
memory_profiler = None
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def assertNotIncreasingMemory(self,
|
||||
f,
|
||||
num_iters=100000,
|
||||
increase_threshold_absolute_mb=10):
|
||||
"""Assert memory usage doesn't increase beyond given threshold for f."""
|
||||
|
||||
with context.eager_mode():
|
||||
# Warm up.
|
||||
f()
|
||||
|
||||
# Wait for background threads to start up and take over memory.
|
||||
# FIXME: The nature of this test leaves few other options. Maybe there
|
||||
# is a better way to do this.
|
||||
time.sleep(4)
|
||||
initial = memory_profiler.memory_usage(-1)[0]
|
||||
for _ in six.moves.range(num_iters):
|
||||
f()
|
||||
increase = memory_profiler.memory_usage(-1)[0] - initial
|
||||
assert increase < increase_threshold_absolute_mb, (
|
||||
"Increase is too high. Initial memory usage: %f MB. Increase: %f MB. "
|
||||
"Maximum allowed increase: %f") % (initial, increase,
|
||||
increase_threshold_absolute_mb)
|
||||
|
||||
@parameterized.parameters(0, 1, 42,)
|
||||
@test_util.run_v1_only("b/121264236")
|
||||
def testInitOnly(self, num_inits):
|
||||
@ -68,6 +100,24 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
@test_util.run_v1_only("b/121264236")
|
||||
def testEagerNoMemoryLeak(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest("Only eager mode test")
|
||||
if memory_profiler is None:
|
||||
self.skipTest("memory_profiler required to run this test")
|
||||
|
||||
def f():
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
dataset, ["/cpu:1", "/cpu:2"])
|
||||
self.evaluate(multi_device_iterator.get_next())
|
||||
del multi_device_iterator
|
||||
del dataset
|
||||
|
||||
self.assertNotIncreasingMemory(
|
||||
f, num_iters=100, increase_threshold_absolute_mb=175)
|
||||
|
||||
@test_util.run_v1_only("b/121264236")
|
||||
def testOneOnSameDevice(self):
|
||||
with ops.device("/cpu:0"):
|
||||
|
||||
@ -158,9 +158,7 @@ class MultiDeviceIterator(object):
|
||||
# TODO(b/121378567): Get rid of this shared_name hack.
|
||||
shared_name = ""
|
||||
if context.executing_eagerly():
|
||||
# Ensure a unique name when eager execution is enabled to avoid spurious
|
||||
# sharing issues.
|
||||
shared_name += str(ops.uid())
|
||||
shared_name = context.shared_name()
|
||||
self._multi_device_iterator_resource = (
|
||||
gen_dataset_ops.multi_device_iterator(
|
||||
devices=self._devices,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user