Avoid the raw input list to avoid handle leaks
Ad a result of this change, I believe a TensorHandle reference count leak during remote copies was fixed. PiperOrigin-RevId: 253736555
This commit is contained in:
parent
222df6844f
commit
3557034977
@ -31,6 +31,12 @@ void EagerOperation::AddInput(tensorflow::TensorHandle* h) {
|
||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||
}
|
||||
|
||||
void EagerOperation::UpdateInput(int i, tensorflow::TensorHandle* h) {
|
||||
h->Ref();
|
||||
inputs_[i]->Unref();
|
||||
inputs_[i] = h;
|
||||
}
|
||||
|
||||
void EagerOperation::ConsumeInput(tensorflow::TensorHandle* h) {
|
||||
inputs_.push_back(h);
|
||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||
|
@ -52,7 +52,9 @@ class EagerOperation {
|
||||
MutableInputs() {
|
||||
return &inputs_;
|
||||
}
|
||||
|
||||
void AddInput(tensorflow::TensorHandle* h);
|
||||
void UpdateInput(int i, tensorflow::TensorHandle* h);
|
||||
void ConsumeInput(tensorflow::TensorHandle* h);
|
||||
|
||||
const tensorflow::string& Name() const { return name_; }
|
||||
|
@ -89,10 +89,9 @@ int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
|
||||
}
|
||||
|
||||
// This function expects *handle to point to an existing tensor handle. The
|
||||
// function will (maybe) update the *handle to be pointed to the newly copied
|
||||
// tensor handle.
|
||||
//
|
||||
// The passed in *handle will be Unreffed if it is replaced.
|
||||
// function will update the *handle to be pointed to the existing input tensor
|
||||
// handle or else the newly copied tensor handle. The existing handle will have
|
||||
// a Ref added, vs the new handle has a Ref due to being newly constructed.
|
||||
//
|
||||
// `op_device_name` is passed in explicitly because `op->device()` might be
|
||||
// unset and we might have selected some specific device to run this op on.
|
||||
@ -100,18 +99,25 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
const string& op_device_name, int i,
|
||||
const Device* expected_input_device,
|
||||
RunMetadata* run_metadata,
|
||||
TensorHandle** handle) {
|
||||
TensorHandle** result) {
|
||||
tensorflow::TensorHandle* handle = op->Inputs()[i];
|
||||
EagerContext* ctx = op->EagerContext();
|
||||
Device* handle_device = (*handle)->device();
|
||||
Device* handle_device = handle->device();
|
||||
const Device* actual_device =
|
||||
handle_device == nullptr ? ctx->HostCPU() : handle_device;
|
||||
|
||||
if (expected_input_device != actual_device) {
|
||||
if (expected_input_device == actual_device) {
|
||||
// No copy was done, so the result is just the original handle with a Ref
|
||||
handle->Ref();
|
||||
*result = handle;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
switch (ctx->GetDevicePlacementPolicy()) {
|
||||
case DEVICE_PLACEMENT_SILENT_FOR_INT32:
|
||||
// TODO(xpan): See if we could bubble python related error up
|
||||
// to python level.
|
||||
if ((*handle)->dtype == DT_INT32) {
|
||||
if (handle->dtype == DT_INT32) {
|
||||
// Note: enabling silent copies of int32 tensors to match behavior
|
||||
// of graph mode.
|
||||
break;
|
||||
@ -123,8 +129,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
" cannot compute ",
|
||||
op->Name(), " as input #", i, " was expected to be on ",
|
||||
expected_input_device->name(), " but is actually on ",
|
||||
actual_device->name(), " (operation running on ", op_device_name,
|
||||
")",
|
||||
actual_device->name(), " (operation running on ", op_device_name, ")",
|
||||
" Tensors can be copied explicitly using:"
|
||||
" `with tf.device(device_name): x = tf.identity(x)`"
|
||||
" or transparently copied by using"
|
||||
@ -132,10 +137,9 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
" Copying tensors between devices may slow down your model");
|
||||
case DEVICE_PLACEMENT_WARN:
|
||||
LOG(WARNING) << "before computing " << op->Name() << " input #" << i
|
||||
<< " was expected to be on "
|
||||
<< expected_input_device->name() << " but is actually on "
|
||||
<< actual_device->name() << " (operation running on "
|
||||
<< op_device_name
|
||||
<< " was expected to be on " << expected_input_device->name()
|
||||
<< " but is actually on " << actual_device->name()
|
||||
<< " (operation running on " << op_device_name
|
||||
<< "). This triggers a copy which can be a performance "
|
||||
"bottleneck.";
|
||||
break;
|
||||
@ -147,7 +151,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
auto pre_time_nanos = Env::Default()->NowNanos();
|
||||
TensorHandle* result_handle = nullptr;
|
||||
Status status = EagerCopyToDevice(
|
||||
*handle, ctx, expected_input_device->name().c_str(), &result_handle);
|
||||
handle, ctx, expected_input_device->name().c_str(), &result_handle);
|
||||
if (run_metadata != nullptr) {
|
||||
auto* step_stats = run_metadata->mutable_step_stats();
|
||||
MaybeInitializeStepStats(step_stats, ctx);
|
||||
@ -156,8 +160,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
|
||||
auto* node_stats = dev_stats->add_node_stats();
|
||||
node_stats->set_node_name("_Send");
|
||||
node_stats->set_all_start_micros(pre_time_nanos /
|
||||
EnvTime::kMicrosToNanos);
|
||||
node_stats->set_all_start_micros(pre_time_nanos / EnvTime::kMicrosToNanos);
|
||||
node_stats->set_all_start_nanos(pre_time_nanos);
|
||||
int64 now_nanos = Env::Default()->NowNanos();
|
||||
node_stats->set_op_end_rel_micros((now_nanos - pre_time_nanos) /
|
||||
@ -169,15 +172,14 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
}
|
||||
if (!status.ok()) {
|
||||
if (result_handle != nullptr) result_handle->Unref();
|
||||
return errors::Internal(
|
||||
"Failed copying input tensor from ", actual_device->name(), " to ",
|
||||
expected_input_device->name(), " in order to run ", op->Name(), ": ",
|
||||
status.error_message());
|
||||
return errors::Internal("Failed copying input tensor from ",
|
||||
actual_device->name(), " to ",
|
||||
expected_input_device->name(), " in order to run ",
|
||||
op->Name(), ": ", status.error_message());
|
||||
}
|
||||
|
||||
(*handle)->Unref();
|
||||
*handle = result_handle;
|
||||
}
|
||||
*result = result_handle;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -195,10 +197,12 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx,
|
||||
}
|
||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||
const Device* expected_device = kernel->InputDevice(i);
|
||||
TensorHandle* handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
|
||||
op, op_device_name, i, expected_device, run_metadata,
|
||||
&((*op->MutableInputs())[i])));
|
||||
tensorflow::TensorHandle* handle = op->Inputs()[i];
|
||||
op, op_device_name, i, expected_device, run_metadata, &handle));
|
||||
op->UpdateInput(i, handle);
|
||||
// Unref handle since it has a ref as an input now
|
||||
handle->Unref();
|
||||
if (handle->dtype != kernel->input_type(i)) {
|
||||
return errors::InvalidArgument(
|
||||
"cannot compute ", op->Name(), " as input #", i, "(zero-based)",
|
||||
@ -503,9 +507,13 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
for (int i = 0; i < op->Inputs().size(); i++) {
|
||||
TensorHandle* input = op->Inputs()[i];
|
||||
if (input->IsRemote()) {
|
||||
TensorHandle* handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(EagerCopyToDevice(
|
||||
input, ctx, device == nullptr ? "" : device->name().c_str(),
|
||||
&(*op->MutableInputs())[i]));
|
||||
&handle));
|
||||
op->UpdateInput(i, handle);
|
||||
// Unref handle since it has a ref as an input now
|
||||
handle->Unref();
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -763,9 +771,13 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
// Always copy to the remote CPU so that the actual device can be
|
||||
// correctly determined after the kernel is selected/instantiated, since
|
||||
// the op might have its inputs on host memory.
|
||||
TensorHandle* handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
|
||||
op, op->Device()->name(), i, remote_cpu_device,
|
||||
/* run_metadata= */ nullptr, &(*op->MutableInputs())[i]));
|
||||
/* run_metadata= */ nullptr, &handle));
|
||||
op->UpdateInput(i, handle);
|
||||
// Unref handle since it has a ref as an input now
|
||||
handle->Unref();
|
||||
}
|
||||
|
||||
tensorflow::TensorHandle* input = op->Inputs()[i];
|
||||
|
@ -545,6 +545,7 @@ cuda_py_test(
|
||||
size = "medium",
|
||||
srcs = ["memory_test.py"],
|
||||
additional_deps = [
|
||||
":remote",
|
||||
"//tensorflow/python/eager:backprop",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/eager:test",
|
||||
@ -553,6 +554,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
shard_count = 4,
|
||||
tags = [
|
||||
"optonly", # The test is too slow in non-opt mode
|
||||
],
|
||||
|
@ -24,6 +24,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
import six
|
||||
|
||||
@ -31,11 +32,14 @@ from tensorflow.python import keras
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import remote
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.variables import Variable
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
# memory_profiler might not be available in the OSS version of TensorFlow.
|
||||
try:
|
||||
@ -55,12 +59,7 @@ class SingleLayerNet(keras.Model):
|
||||
return self.fc1(x)
|
||||
|
||||
|
||||
class MemoryTest(test.TestCase):
|
||||
|
||||
def assertNotIncreasingMemory(self,
|
||||
f,
|
||||
num_iters=100000,
|
||||
increase_threshold_absolute_mb=10):
|
||||
def assert_no_leak(f, num_iters=100000, increase_threshold_absolute_mb=10):
|
||||
"""Assert memory usage doesn't increase beyond given threshold for f."""
|
||||
|
||||
with context.eager_mode():
|
||||
@ -84,6 +83,9 @@ class MemoryTest(test.TestCase):
|
||||
"Maximum allowed increase: %f") % (initial, increase,
|
||||
increase_threshold_absolute_mb)
|
||||
|
||||
|
||||
class MemoryTest(test.TestCase):
|
||||
|
||||
def testMemoryLeakAnonymousVariable(self):
|
||||
if memory_profiler is None:
|
||||
self.skipTest("memory_profiler required to run this test")
|
||||
@ -92,7 +94,7 @@ class MemoryTest(test.TestCase):
|
||||
inputs = Variable(array_ops.zeros([32, 100], dtypes.float32))
|
||||
del inputs
|
||||
|
||||
self.assertNotIncreasingMemory(f, num_iters=10000)
|
||||
assert_no_leak(f, num_iters=10000)
|
||||
|
||||
def testMemoryLeakInSimpleModelForwardOnly(self):
|
||||
if memory_profiler is None:
|
||||
@ -105,7 +107,7 @@ class MemoryTest(test.TestCase):
|
||||
with backprop.GradientTape():
|
||||
net(inputs)
|
||||
|
||||
self.assertNotIncreasingMemory(f)
|
||||
assert_no_leak(f)
|
||||
|
||||
def testMemoryLeakInSimpleModelForwardAndBackward(self):
|
||||
if memory_profiler is None:
|
||||
@ -122,7 +124,7 @@ class MemoryTest(test.TestCase):
|
||||
|
||||
del tape
|
||||
|
||||
self.assertNotIncreasingMemory(f)
|
||||
assert_no_leak(f)
|
||||
|
||||
def testMemoryLeakInFunction(self):
|
||||
if memory_profiler is None:
|
||||
@ -136,8 +138,39 @@ class MemoryTest(test.TestCase):
|
||||
|
||||
graph(constant_op.constant(42))
|
||||
|
||||
self.assertNotIncreasingMemory(
|
||||
f, num_iters=1000, increase_threshold_absolute_mb=20)
|
||||
assert_no_leak(f, num_iters=1000, increase_threshold_absolute_mb=30)
|
||||
|
||||
|
||||
class RemoteWorkerMemoryTest(test.TestCase):
|
||||
|
||||
def __init__(self, method):
|
||||
super(RemoteWorkerMemoryTest, self).__init__(method)
|
||||
|
||||
# used for remote worker tests
|
||||
os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
|
||||
self._cached_server = server_lib.Server.create_local_server()
|
||||
self._cached_server_target = self._cached_server.target[len("grpc://"):]
|
||||
|
||||
def testMemoryLeakInLocalCopy(self):
|
||||
if memory_profiler is None:
|
||||
self.skipTest("memory_profiler required to run this test")
|
||||
|
||||
remote.connect_to_remote_host(self._cached_server_target)
|
||||
|
||||
# Run a function locally with the input on a remote worker and ensure we
|
||||
# do not leak a reference to the remote tensor.
|
||||
|
||||
@def_function.function
|
||||
def local_func(i):
|
||||
return i
|
||||
|
||||
def func():
|
||||
with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
|
||||
x = array_ops.zeros([1000, 1000], dtypes.int32)
|
||||
|
||||
local_func(x)
|
||||
|
||||
assert_no_leak(func, num_iters=100, increase_threshold_absolute_mb=50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user