Fixing MultiDeviceIterator memory leak issue in eager mode.

In graph mode, we rely on the destruction of the MultiDeviceIteratorHandleOp kernel to decrement the ref count of the resource. Since we don't destroy kernels in eager mode, we explicitly added a destroy_resource_op to compensate.
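
For context, the graph-mode cleanup lives in the kernel's destructor; a condensed sketch of that path (not the verbatim kernel code, which wraps the Delete call in extra error handling):

    // Sketch of the graph-mode cleanup path (condensed, not verbatim).
    ~MultiDeviceIteratorHandleOp() override {
      if (resource_ != nullptr) {
        resource_->Unref();  // Drop the reference the kernel itself holds.
        if (cinfo_.resource_is_private_to_kernel()) {
          // Deleting the entry also drops the ResourceMgr's own reference.
          // The call may fail if a session reset already removed the entry.
          cinfo_.resource_manager()
              ->Delete<MultiDeviceIterator>(cinfo_.container(), cinfo_.name())
              .IgnoreError();
        }
      }
    }

In eager mode this destructor never runs, which is why the Python wrapper issues the explicit destroy_resource_op instead.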

The problem is that this isn't enough. ResourceMgr::LookupOrCreate leaves the resource with a ref count of 2 (one reference held inside the ResourceMgr and one returned to the caller), and in graph mode the destructor effectively performed two Unrefs to balance them: the explicit one, plus the one implied by deleting the entry from the ResourceMgr. The destroy_resource_op only accounts for one of those references, so in eager mode the ref count remained at 1 and never dropped to zero.
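
The accounting, written out (an illustrative sketch; `mgr`, `container`, `name`, and `creator_fn` are placeholders):

    // Ref-count ledger for the old path (illustrative, not verbatim code).
    MultiDeviceIterator* resource;
    TF_RETURN_IF_ERROR(mgr->LookupOrCreate<MultiDeviceIterator>(
        container, name, &resource, creator_fn));
    // Ref count is now 2: one reference held inside the ResourceMgr, one
    // returned to the caller and stashed in the kernel as resource_.

    // Graph mode balances both in the destructor:
    resource->Unref();                                  // 2 -> 1
    mgr->Delete<MultiDeviceIterator>(container, name);  // 1 -> 0, freed

    // Eager mode only ran the destroy_resource_op, i.e. the Delete half:
    mgr->Delete<MultiDeviceIterator>(container, name);  // 2 -> 1, leaked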

The fix is to handle the eager-mode case separately, similar to what we did for AnonymousIteratorHandleOp. Instead of creating a whole new kernel, we reuse the existing one and use a special shared_name argument to signal when to switch behavior. Now, in eager mode, the resource's ref count is 1 after the HandleOp kernel runs, so the single decrement from the destroy_resource_op frees it.
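
Distilled, the kernel change amounts to the following (a condensed sketch of the diff below, not a verbatim excerpt):

    // Condensed sketch of the new Compute() logic shown in the diff below.
    if (name_ == ResourceHandle::ANONYMOUS_NAME) {
      // Eager mode: fabricate a unique name rather than sharing the resource.
      unique_name = strings::StrCat("_AnonymousMultiDeviceIterator",
                                    current_id_.fetch_add(1));
      container_name = "AnonymousMultiDeviceIterator";
      // `new` starts the ref count at 1, and Create() transfers that single
      // reference to the ResourceMgr, so the count stays at 1. The later
      // destroy_resource_op's Delete drops it to 0 and the iterator is freed.
      resource = new MultiDeviceIterator(/* ... */);
      OP_REQUIRES_OK(context, mgr->Create<MultiDeviceIterator>(
                                  container_name, unique_name, resource));
    } else {
      // Graph mode: the original LookupOrCreate path, which leaves the ref
      // count at 2 and is balanced by the two decrements in the destructor.
    }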

PiperOrigin-RevId: 231333966
Rohan Jain, 2019-01-28 19:45:17 -08:00 (committed by TensorFlower Gardener)
parent 601921925d
commit ad65038c90
3 changed files with 94 additions and 23 deletions


@@ -359,6 +359,9 @@ class MultiDeviceIterator : public ResourceBase {
std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
};
// Used to generate unique names for anonymous multi device iterators.
static std::atomic<int64> current_id_;
// Just creates a MultiDeviceIterator and returns it.
class MultiDeviceIteratorHandleOp : public OpKernel {
public:
@@ -388,6 +391,8 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
}
void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
string unique_name = cinfo_.name();
string container_name = cinfo_.container();
{
mutex_lock l(mu_);
if (resource_ == nullptr) {
@@ -402,31 +407,49 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
MultiDeviceIterator* resource;
OP_REQUIRES_OK(context,
mgr->LookupOrCreate<MultiDeviceIterator>(
cinfo_.container(), cinfo_.name(), &resource,
[this, lib, &flib_def, &pflr,
&function_handle_cache](MultiDeviceIterator** ret)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new MultiDeviceIterator(
output_types_, output_shapes_, devices_,
std::move(flib_def), std::move(pflr), lib,
std::move(function_handle_cache));
return Status::OK();
}));
Status s = VerifyResource(resource);
if (TF_PREDICT_FALSE(!s.ok())) {
resource->Unref();
context->SetStatus(s);
return;
if (name_ == ResourceHandle::ANONYMOUS_NAME) {
unique_name = strings::StrCat("_AnonymousMultiDeviceIterator",
current_id_.fetch_add(1));
container_name = "AnonymousMultiDeviceIterator";
resource = new MultiDeviceIterator(
output_types_, output_shapes_, devices_, std::move(flib_def),
std::move(pflr), lib, std::move(function_handle_cache));
OP_REQUIRES_OK(context, mgr->Create<MultiDeviceIterator>(
container_name, unique_name, resource));
Status s = VerifyResource(resource);
if (TF_PREDICT_FALSE(!s.ok())) {
resource->Unref();
context->SetStatus(s);
return;
}
} else {
unique_name = cinfo_.name();
container_name = cinfo_.container();
OP_REQUIRES_OK(context,
mgr->LookupOrCreate<MultiDeviceIterator>(
container_name, unique_name, &resource,
[this, lib, &flib_def, &pflr,
&function_handle_cache](MultiDeviceIterator** ret)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new MultiDeviceIterator(
output_types_, output_shapes_, devices_,
std::move(flib_def), std::move(pflr),
lib, std::move(function_handle_cache));
return Status::OK();
}));
Status s = VerifyResource(resource);
if (TF_PREDICT_FALSE(!s.ok())) {
resource->Unref();
context->SetStatus(s);
return;
}
resource_ = resource;
}
resource_ = resource;
}
}
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
context, 0, cinfo_.container(), cinfo_.name(),
context, 0, container_name, unique_name,
MakeTypeIndex<MultiDeviceIterator>()));
}


@@ -18,7 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from absl.testing import parameterized
import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.experimental.ops import optimization
@@ -34,10 +36,40 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
# memory_profiler might not be available in the OSS version of TensorFlow.
try:
import memory_profiler # pylint:disable=g-import-not-at-top
except ImportError:
memory_profiler = None
@test_util.run_all_in_graph_and_eager_modes
class MultiDeviceIteratorTest(test_base.DatasetTestBase,
parameterized.TestCase):
def assertNotIncreasingMemory(self,
f,
num_iters=100000,
increase_threshold_absolute_mb=10):
"""Assert memory usage doesn't increase beyond given threshold for f."""
with context.eager_mode():
# Warm up.
f()
# Wait for background threads to start up and take over memory.
# FIXME: The nature of this test leaves few other options. Maybe there
# is a better way to do this.
time.sleep(4)
initial = memory_profiler.memory_usage(-1)[0]
for _ in six.moves.range(num_iters):
f()
increase = memory_profiler.memory_usage(-1)[0] - initial
assert increase < increase_threshold_absolute_mb, (
"Increase is too high. Initial memory usage: %f MB. Increase: %f MB. "
"Maximum allowed increase: %f") % (initial, increase,
increase_threshold_absolute_mb)
@parameterized.parameters(0, 1, 42,)
@test_util.run_v1_only("b/121264236")
def testInitOnly(self, num_inits):
@@ -68,6 +100,24 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
@test_util.run_v1_only("b/121264236")
def testEagerNoMemoryLeak(self):
if not context.executing_eagerly():
self.skipTest("Only eager mode test")
if memory_profiler is None:
self.skipTest("memory_profiler required to run this test")
def f():
dataset = dataset_ops.Dataset.range(10)
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
dataset, ["/cpu:1", "/cpu:2"])
self.evaluate(multi_device_iterator.get_next())
del multi_device_iterator
del dataset
self.assertNotIncreasingMemory(
f, num_iters=100, increase_threshold_absolute_mb=175)
@test_util.run_v1_only("b/121264236")
def testOneOnSameDevice(self):
with ops.device("/cpu:0"):


@@ -158,9 +158,7 @@ class MultiDeviceIterator(object):
# TODO(b/121378567): Get rid of this shared_name hack.
shared_name = ""
if context.executing_eagerly():
# Ensure a unique name when eager execution is enabled to avoid spurious
# sharing issues.
shared_name += str(ops.uid())
shared_name = context.shared_name()
self._multi_device_iterator_resource = (
gen_dataset_ops.multi_device_iterator(
devices=self._devices,