Avoid the raw input list to prevent handle leaks

As a result of this change, I believe a TensorHandle reference count
leak during remote copies was fixed.

PiperOrigin-RevId: 253736555
Gaurav Jain 2019-06-18 00:19:28 -07:00 committed by TensorFlower Gardener
parent 222df6844f
commit 3557034977
5 changed files with 163 additions and 108 deletions
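
The change below replaces direct writes into the op's raw input list with an explicit UpdateInput call plus a caller-side Unref. As a rough illustration of the leak this closes, here is a minimal sketch; Handle, slot, and the surrounding names are stand-ins, not TensorFlow code:

#include <cassert>

// Minimal stand-in for TensorHandle's intrusive reference count.
struct Handle {
  int refs = 1;                 // a newly constructed handle holds one ref
  void Ref() { ++refs; }
  void Unref() { --refs; }      // the real class destroys itself at zero
};

int main() {
  Handle* input = new Handle;   // the op's current input, refs == 1
  Handle* slot = input;         // the raw input-list entry

  // Old pattern: the copy is written straight into the raw slot, so the
  // previous handle is overwritten without an Unref and its reference is
  // never released.
  slot = new Handle;            // stands in for EagerCopyToDevice(..., &slot)
  assert(input->refs == 1);     // leaked: nothing will ever drop this ref

  // New pattern: copy into a temporary, swap with UpdateInput semantics
  // (Ref the new handle, Unref the old), then drop the temporary ref.
  Handle* handle = new Handle;  // the copied handle, refs == 1
  handle->Ref();                // the op takes its own reference
  slot->Unref();                // release the handle being replaced
  slot = handle;
  handle->Unref();              // caller drops its temporary reference
  assert(slot == handle && slot->refs == 1);  // only the op's ref remains
  return 0;                     // (toy code: objects deliberately not freed)
}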


@@ -31,6 +31,12 @@ void EagerOperation::AddInput(tensorflow::TensorHandle* h) {
   attrs_.NumInputs(static_cast<int>(inputs_.size()));
 }
 
+void EagerOperation::UpdateInput(int i, tensorflow::TensorHandle* h) {
+  h->Ref();
+  inputs_[i]->Unref();
+  inputs_[i] = h;
+}
+
 void EagerOperation::ConsumeInput(tensorflow::TensorHandle* h) {
   inputs_.push_back(h);
   attrs_.NumInputs(static_cast<int>(inputs_.size()));
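
A detail worth noting in UpdateInput above: the new reference is taken before the old one is released, which keeps the degenerate case of replacing an input with itself safe. A minimal sketch with a stand-in Handle type (not TensorFlow code):

#include <cassert>

struct Handle {
  int refs = 1;
  bool alive = true;
  void Ref() { ++refs; }
  void Unref() {
    if (--refs == 0) alive = false;  // stands in for "delete this"
  }
};

int main() {
  // Replace an input with itself while holding the only reference.
  Handle h;               // refs == 1
  Handle* slot = &h;      // inputs_[i]
  Handle* incoming = &h;  // aliases the current input

  incoming->Ref();        // 1 -> 2: take the new reference first
  slot->Unref();          // 2 -> 1: then release the old one
  slot = incoming;

  // Reversing the two calls would drop the count to zero and destroy the
  // handle before the new reference could be taken.
  assert(slot->alive && slot->refs == 1);
  return 0;
}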


@@ -52,7 +52,9 @@ class EagerOperation {
   MutableInputs() {
     return &inputs_;
   }
+
   void AddInput(tensorflow::TensorHandle* h);
+  void UpdateInput(int i, tensorflow::TensorHandle* h);
   void ConsumeInput(tensorflow::TensorHandle* h);
 
   const tensorflow::string& Name() const { return name_; }


@@ -89,10 +89,9 @@ int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
 }
 
 // This function expects *handle to point to an existing tensor handle. The
-// function will (maybe) update the *handle to be pointed to the newly copied
-// tensor handle.
-//
-// The passed in *handle will be Unreffed if it is replaced.
+// function will update the *handle to be pointed to the existing input tensor
+// handle or else the newly copied tensor handle. The existing handle will have
+// a Ref added, vs the new handle has a Ref due to being newly constructed.
 //
 // `op_device_name` is passed in explicitly because `op->device()` might be
 // unset and we might have selected some specific device to run this op on.
@@ -100,18 +99,25 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
                                       const string& op_device_name, int i,
                                       const Device* expected_input_device,
                                       RunMetadata* run_metadata,
-                                      TensorHandle** handle) {
+                                      TensorHandle** result) {
+  tensorflow::TensorHandle* handle = op->Inputs()[i];
   EagerContext* ctx = op->EagerContext();
-  Device* handle_device = (*handle)->device();
+  Device* handle_device = handle->device();
   const Device* actual_device =
       handle_device == nullptr ? ctx->HostCPU() : handle_device;
-  if (expected_input_device != actual_device) {
+  if (expected_input_device == actual_device) {
+    // No copy was done, so the result is just the original handle with a Ref
+    handle->Ref();
+    *result = handle;
+    return Status::OK();
+  }
+
   switch (ctx->GetDevicePlacementPolicy()) {
     case DEVICE_PLACEMENT_SILENT_FOR_INT32:
       // TODO(xpan): See if we could bubble python related error up
       // to python level.
-      if ((*handle)->dtype == DT_INT32) {
+      if (handle->dtype == DT_INT32) {
        // Note: enabling silent copies of int32 tensors to match behavior
        // of graph mode.
        break;
@@ -123,8 +129,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
           " cannot compute ",
           op->Name(), " as input #", i, " was expected to be on ",
           expected_input_device->name(), " but is actually on ",
-          actual_device->name(), " (operation running on ", op_device_name,
-          ")",
+          actual_device->name(), " (operation running on ", op_device_name, ")",
           " Tensors can be copied explicitly using:"
           " `with tf.device(device_name): x = tf.identity(x)`"
           " or transparently copied by using"
@@ -132,10 +137,9 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
           " Copying tensors between devices may slow down your model");
     case DEVICE_PLACEMENT_WARN:
       LOG(WARNING) << "before computing " << op->Name() << " input #" << i
-                   << " was expected to be on "
-                   << expected_input_device->name() << " but is actually on "
-                   << actual_device->name() << " (operation running on "
-                   << op_device_name
+                   << " was expected to be on " << expected_input_device->name()
+                   << " but is actually on " << actual_device->name()
+                   << " (operation running on " << op_device_name
                    << "). This triggers a copy which can be a performance "
                       "bottleneck.";
       break;
@@ -147,7 +151,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
   auto pre_time_nanos = Env::Default()->NowNanos();
   TensorHandle* result_handle = nullptr;
   Status status = EagerCopyToDevice(
-      *handle, ctx, expected_input_device->name().c_str(), &result_handle);
+      handle, ctx, expected_input_device->name().c_str(), &result_handle);
   if (run_metadata != nullptr) {
     auto* step_stats = run_metadata->mutable_step_stats();
     MaybeInitializeStepStats(step_stats, ctx);
@@ -156,8 +160,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
     auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
     auto* node_stats = dev_stats->add_node_stats();
     node_stats->set_node_name("_Send");
-    node_stats->set_all_start_micros(pre_time_nanos /
-                                     EnvTime::kMicrosToNanos);
+    node_stats->set_all_start_micros(pre_time_nanos / EnvTime::kMicrosToNanos);
     node_stats->set_all_start_nanos(pre_time_nanos);
     int64 now_nanos = Env::Default()->NowNanos();
     node_stats->set_op_end_rel_micros((now_nanos - pre_time_nanos) /
@@ -169,15 +172,14 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
   }
   if (!status.ok()) {
     if (result_handle != nullptr) result_handle->Unref();
-    return errors::Internal(
-        "Failed copying input tensor from ", actual_device->name(), " to ",
-        expected_input_device->name(), " in order to run ", op->Name(), ": ",
-        status.error_message());
+    return errors::Internal("Failed copying input tensor from ",
+                            actual_device->name(), " to ",
+                            expected_input_device->name(), " in order to run ",
+                            op->Name(), ": ", status.error_message());
   }
-  (*handle)->Unref();
-  *handle = result_handle;
-  }
+  *result = result_handle;
+
   return Status::OK();
 }
@@ -195,10 +197,12 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx,
   }
   for (int i = 0; i < op->Inputs().size(); ++i) {
     const Device* expected_device = kernel->InputDevice(i);
+    TensorHandle* handle = nullptr;
     TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
-        op, op_device_name, i, expected_device, run_metadata,
-        &((*op->MutableInputs())[i])));
-    tensorflow::TensorHandle* handle = op->Inputs()[i];
+        op, op_device_name, i, expected_device, run_metadata, &handle));
+    op->UpdateInput(i, handle);
+    // Unref handle since it has a ref as an input now
+    handle->Unref();
     if (handle->dtype != kernel->input_type(i)) {
       return errors::InvalidArgument(
           "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
@@ -503,9 +507,13 @@ Status EagerLocalExecute(EagerOperation* op,
   for (int i = 0; i < op->Inputs().size(); i++) {
     TensorHandle* input = op->Inputs()[i];
     if (input->IsRemote()) {
+      TensorHandle* handle = nullptr;
       TF_RETURN_IF_ERROR(EagerCopyToDevice(
           input, ctx, device == nullptr ? "" : device->name().c_str(),
-          &(*op->MutableInputs())[i]));
+          &handle));
+      op->UpdateInput(i, handle);
+      // Unref handle since it has a ref as an input now
+      handle->Unref();
     }
   }
   TF_RETURN_IF_ERROR(
@@ -763,9 +771,13 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
       // Always copy to the remote CPU so that the actual device can be
       // correctly determined after the kernel is selected/instantiated, since
       // the op might have its inputs on host memory.
+      TensorHandle* handle = nullptr;
       TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
           op, op->Device()->name(), i, remote_cpu_device,
-          /* run_metadata= */ nullptr, &(*op->MutableInputs())[i]));
+          /* run_metadata= */ nullptr, &handle));
+      op->UpdateInput(i, handle);
+      // Unref handle since it has a ref as an input now
+      handle->Unref();
     }
 
     tensorflow::TensorHandle* input = op->Inputs()[i];
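
All three call sites updated above now share one contract: MaybeCopyInputToExpectedDevice (or EagerCopyToDevice) hands the caller exactly one reference, whether it reuses the existing handle or creates a copy, so the caller can unconditionally UpdateInput and then Unref its temporary. A toy model of that invariant (illustrative only, not TensorFlow code):

#include <cassert>

struct Handle {
  int refs = 1;
  void Ref() { ++refs; }
  void Unref() { --refs; }
};

// Reuse path adds a ref; copy path returns a fresh handle that already
// carries one ref from construction. Either way the caller owns one ref.
Handle* MaybeCopy(Handle* existing, bool needs_copy) {
  if (!needs_copy) {
    existing->Ref();
    return existing;
  }
  return new Handle;
}

int main() {
  for (bool needs_copy : {false, true}) {
    Handle* input = new Handle;  // the op's input slot, refs == 1
    Handle* handle = MaybeCopy(input, needs_copy);

    // The caller-side pattern from the hunks above:
    handle->Ref();               // UpdateInput: the op's own reference
    input->Unref();              // UpdateInput: release the old input
    handle->Unref();             // drop the caller's temporary reference

    // Both paths leave the op holding exactly one reference.
    assert(handle->refs == 1);
    if (needs_copy) {
      assert(input->refs == 0);  // replaced input was fully released
      delete input;
    }
    delete handle;               // tear down the toy
  }
  return 0;
}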


@@ -545,6 +545,7 @@ cuda_py_test(
     size = "medium",
     srcs = ["memory_test.py"],
     additional_deps = [
+        ":remote",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/keras",
         "//tensorflow/python/eager:test",
@@ -553,6 +554,7 @@ cuda_py_test(
         "//tensorflow/python:framework_test_lib",
         "@six_archive//:six",
     ],
+    shard_count = 4,
     tags = [
         "optonly",  # The test is too slow in non-opt mode
     ],


@@ -24,6 +24,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import time
 
 import six
@@ -31,11 +32,14 @@ from tensorflow.python import keras
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.eager import remote
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.variables import Variable
+from tensorflow.python.training import server_lib
 
 # memory_profiler might not be available in the OSS version of TensorFlow.
 try:
@@ -55,12 +59,7 @@ class SingleLayerNet(keras.Model):
     return self.fc1(x)
 
 
-class MemoryTest(test.TestCase):
-
-  def assertNotIncreasingMemory(self,
-                                f,
-                                num_iters=100000,
-                                increase_threshold_absolute_mb=10):
+def assert_no_leak(f, num_iters=100000, increase_threshold_absolute_mb=10):
   """Assert memory usage doesn't increase beyond given threshold for f."""
   with context.eager_mode():
@@ -84,6 +83,9 @@ class MemoryTest(test.TestCase):
             "Maximum allowed increase: %f") % (initial, increase,
                                                increase_threshold_absolute_mb)
 
+
+class MemoryTest(test.TestCase):
+
   def testMemoryLeakAnonymousVariable(self):
     if memory_profiler is None:
       self.skipTest("memory_profiler required to run this test")
@@ -92,7 +94,7 @@ class MemoryTest(test.TestCase):
       inputs = Variable(array_ops.zeros([32, 100], dtypes.float32))
       del inputs
 
-    self.assertNotIncreasingMemory(f, num_iters=10000)
+    assert_no_leak(f, num_iters=10000)
 
   def testMemoryLeakInSimpleModelForwardOnly(self):
     if memory_profiler is None:
@@ -105,7 +107,7 @@ class MemoryTest(test.TestCase):
       with backprop.GradientTape():
         net(inputs)
 
-    self.assertNotIncreasingMemory(f)
+    assert_no_leak(f)
 
   def testMemoryLeakInSimpleModelForwardAndBackward(self):
     if memory_profiler is None:
@@ -122,7 +124,7 @@ class MemoryTest(test.TestCase):
 
       del tape
 
-    self.assertNotIncreasingMemory(f)
+    assert_no_leak(f)
 
   def testMemoryLeakInFunction(self):
     if memory_profiler is None:
@@ -136,8 +138,39 @@ class MemoryTest(test.TestCase):
       graph(constant_op.constant(42))
 
-    self.assertNotIncreasingMemory(
-        f, num_iters=1000, increase_threshold_absolute_mb=20)
+    assert_no_leak(f, num_iters=1000, increase_threshold_absolute_mb=30)
+
+
+class RemoteWorkerMemoryTest(test.TestCase):
+
+  def __init__(self, method):
+    super(RemoteWorkerMemoryTest, self).__init__(method)
+
+    # used for remote worker tests
+    os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
+    self._cached_server = server_lib.Server.create_local_server()
+    self._cached_server_target = self._cached_server.target[len("grpc://"):]
+
+  def testMemoryLeakInLocalCopy(self):
+    if memory_profiler is None:
+      self.skipTest("memory_profiler required to run this test")
+
+    remote.connect_to_remote_host(self._cached_server_target)
+
+    # Run a function locally with the input on a remote worker and ensure we
+    # do not leak a reference to the remote tensor.
+    @def_function.function
+    def local_func(i):
+      return i
+
+    def func():
+      with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
+        x = array_ops.zeros([1000, 1000], dtypes.int32)
+      local_func(x)
+
+    assert_no_leak(func, num_iters=100, increase_threshold_absolute_mb=50)
 
 
 if __name__ == "__main__":