Avoid the raw input list to avoid handle leaks
Ad a result of this change, I believe a TensorHandle reference count leak during remote copies was fixed. PiperOrigin-RevId: 253736555
This commit is contained in:
parent
222df6844f
commit
3557034977
@ -31,6 +31,12 @@ void EagerOperation::AddInput(tensorflow::TensorHandle* h) {
|
|||||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EagerOperation::UpdateInput(int i, tensorflow::TensorHandle* h) {
|
||||||
|
h->Ref();
|
||||||
|
inputs_[i]->Unref();
|
||||||
|
inputs_[i] = h;
|
||||||
|
}
|
||||||
|
|
||||||
void EagerOperation::ConsumeInput(tensorflow::TensorHandle* h) {
|
void EagerOperation::ConsumeInput(tensorflow::TensorHandle* h) {
|
||||||
inputs_.push_back(h);
|
inputs_.push_back(h);
|
||||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||||
|
@ -52,7 +52,9 @@ class EagerOperation {
|
|||||||
MutableInputs() {
|
MutableInputs() {
|
||||||
return &inputs_;
|
return &inputs_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddInput(tensorflow::TensorHandle* h);
|
void AddInput(tensorflow::TensorHandle* h);
|
||||||
|
void UpdateInput(int i, tensorflow::TensorHandle* h);
|
||||||
void ConsumeInput(tensorflow::TensorHandle* h);
|
void ConsumeInput(tensorflow::TensorHandle* h);
|
||||||
|
|
||||||
const tensorflow::string& Name() const { return name_; }
|
const tensorflow::string& Name() const { return name_; }
|
||||||
|
@ -89,10 +89,9 @@ int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// This function expects *handle to point to an existing tensor handle. The
|
// This function expects *handle to point to an existing tensor handle. The
|
||||||
// function will (maybe) update the *handle to be pointed to the newly copied
|
// function will update the *handle to be pointed to the existing input tensor
|
||||||
// tensor handle.
|
// handle or else the newly copied tensor handle. The existing handle will have
|
||||||
//
|
// a Ref added, vs the new handle has a Ref due to being newly constructed.
|
||||||
// The passed in *handle will be Unreffed if it is replaced.
|
|
||||||
//
|
//
|
||||||
// `op_device_name` is passed in explicitly because `op->device()` might be
|
// `op_device_name` is passed in explicitly because `op->device()` might be
|
||||||
// unset and we might have selected some specific device to run this op on.
|
// unset and we might have selected some specific device to run this op on.
|
||||||
@ -100,18 +99,25 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
|||||||
const string& op_device_name, int i,
|
const string& op_device_name, int i,
|
||||||
const Device* expected_input_device,
|
const Device* expected_input_device,
|
||||||
RunMetadata* run_metadata,
|
RunMetadata* run_metadata,
|
||||||
TensorHandle** handle) {
|
TensorHandle** result) {
|
||||||
|
tensorflow::TensorHandle* handle = op->Inputs()[i];
|
||||||
EagerContext* ctx = op->EagerContext();
|
EagerContext* ctx = op->EagerContext();
|
||||||
Device* handle_device = (*handle)->device();
|
Device* handle_device = handle->device();
|
||||||
const Device* actual_device =
|
const Device* actual_device =
|
||||||
handle_device == nullptr ? ctx->HostCPU() : handle_device;
|
handle_device == nullptr ? ctx->HostCPU() : handle_device;
|
||||||
|
|
||||||
if (expected_input_device != actual_device) {
|
if (expected_input_device == actual_device) {
|
||||||
|
// No copy was done, so the result is just the original handle with a Ref
|
||||||
|
handle->Ref();
|
||||||
|
*result = handle;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
switch (ctx->GetDevicePlacementPolicy()) {
|
switch (ctx->GetDevicePlacementPolicy()) {
|
||||||
case DEVICE_PLACEMENT_SILENT_FOR_INT32:
|
case DEVICE_PLACEMENT_SILENT_FOR_INT32:
|
||||||
// TODO(xpan): See if we could bubble python related error up
|
// TODO(xpan): See if we could bubble python related error up
|
||||||
// to python level.
|
// to python level.
|
||||||
if ((*handle)->dtype == DT_INT32) {
|
if (handle->dtype == DT_INT32) {
|
||||||
// Note: enabling silent copies of int32 tensors to match behavior
|
// Note: enabling silent copies of int32 tensors to match behavior
|
||||||
// of graph mode.
|
// of graph mode.
|
||||||
break;
|
break;
|
||||||
@ -123,8 +129,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
|||||||
" cannot compute ",
|
" cannot compute ",
|
||||||
op->Name(), " as input #", i, " was expected to be on ",
|
op->Name(), " as input #", i, " was expected to be on ",
|
||||||
expected_input_device->name(), " but is actually on ",
|
expected_input_device->name(), " but is actually on ",
|
||||||
actual_device->name(), " (operation running on ", op_device_name,
|
actual_device->name(), " (operation running on ", op_device_name, ")",
|
||||||
")",
|
|
||||||
" Tensors can be copied explicitly using:"
|
" Tensors can be copied explicitly using:"
|
||||||
" `with tf.device(device_name): x = tf.identity(x)`"
|
" `with tf.device(device_name): x = tf.identity(x)`"
|
||||||
" or transparently copied by using"
|
" or transparently copied by using"
|
||||||
@ -132,10 +137,9 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
|||||||
" Copying tensors between devices may slow down your model");
|
" Copying tensors between devices may slow down your model");
|
||||||
case DEVICE_PLACEMENT_WARN:
|
case DEVICE_PLACEMENT_WARN:
|
||||||
LOG(WARNING) << "before computing " << op->Name() << " input #" << i
|
LOG(WARNING) << "before computing " << op->Name() << " input #" << i
|
||||||
<< " was expected to be on "
|
<< " was expected to be on " << expected_input_device->name()
|
||||||
<< expected_input_device->name() << " but is actually on "
|
<< " but is actually on " << actual_device->name()
|
||||||
<< actual_device->name() << " (operation running on "
|
<< " (operation running on " << op_device_name
|
||||||
<< op_device_name
|
|
||||||
<< "). This triggers a copy which can be a performance "
|
<< "). This triggers a copy which can be a performance "
|
||||||
"bottleneck.";
|
"bottleneck.";
|
||||||
break;
|
break;
|
||||||
@ -147,7 +151,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
|||||||
auto pre_time_nanos = Env::Default()->NowNanos();
|
auto pre_time_nanos = Env::Default()->NowNanos();
|
||||||
TensorHandle* result_handle = nullptr;
|
TensorHandle* result_handle = nullptr;
|
||||||
Status status = EagerCopyToDevice(
|
Status status = EagerCopyToDevice(
|
||||||
*handle, ctx, expected_input_device->name().c_str(), &result_handle);
|
handle, ctx, expected_input_device->name().c_str(), &result_handle);
|
||||||
if (run_metadata != nullptr) {
|
if (run_metadata != nullptr) {
|
||||||
auto* step_stats = run_metadata->mutable_step_stats();
|
auto* step_stats = run_metadata->mutable_step_stats();
|
||||||
MaybeInitializeStepStats(step_stats, ctx);
|
MaybeInitializeStepStats(step_stats, ctx);
|
||||||
@ -156,8 +160,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
|||||||
auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
|
auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
|
||||||
auto* node_stats = dev_stats->add_node_stats();
|
auto* node_stats = dev_stats->add_node_stats();
|
||||||
node_stats->set_node_name("_Send");
|
node_stats->set_node_name("_Send");
|
||||||
node_stats->set_all_start_micros(pre_time_nanos /
|
node_stats->set_all_start_micros(pre_time_nanos / EnvTime::kMicrosToNanos);
|
||||||
EnvTime::kMicrosToNanos);
|
|
||||||
node_stats->set_all_start_nanos(pre_time_nanos);
|
node_stats->set_all_start_nanos(pre_time_nanos);
|
||||||
int64 now_nanos = Env::Default()->NowNanos();
|
int64 now_nanos = Env::Default()->NowNanos();
|
||||||
node_stats->set_op_end_rel_micros((now_nanos - pre_time_nanos) /
|
node_stats->set_op_end_rel_micros((now_nanos - pre_time_nanos) /
|
||||||
@ -169,15 +172,14 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
|||||||
}
|
}
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
if (result_handle != nullptr) result_handle->Unref();
|
if (result_handle != nullptr) result_handle->Unref();
|
||||||
return errors::Internal(
|
return errors::Internal("Failed copying input tensor from ",
|
||||||
"Failed copying input tensor from ", actual_device->name(), " to ",
|
actual_device->name(), " to ",
|
||||||
expected_input_device->name(), " in order to run ", op->Name(), ": ",
|
expected_input_device->name(), " in order to run ",
|
||||||
status.error_message());
|
op->Name(), ": ", status.error_message());
|
||||||
}
|
}
|
||||||
|
|
||||||
(*handle)->Unref();
|
*result = result_handle;
|
||||||
*handle = result_handle;
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,10 +197,12 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx,
|
|||||||
}
|
}
|
||||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||||
const Device* expected_device = kernel->InputDevice(i);
|
const Device* expected_device = kernel->InputDevice(i);
|
||||||
|
TensorHandle* handle = nullptr;
|
||||||
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
|
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
|
||||||
op, op_device_name, i, expected_device, run_metadata,
|
op, op_device_name, i, expected_device, run_metadata, &handle));
|
||||||
&((*op->MutableInputs())[i])));
|
op->UpdateInput(i, handle);
|
||||||
tensorflow::TensorHandle* handle = op->Inputs()[i];
|
// Unref handle since it has a ref as an input now
|
||||||
|
handle->Unref();
|
||||||
if (handle->dtype != kernel->input_type(i)) {
|
if (handle->dtype != kernel->input_type(i)) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"cannot compute ", op->Name(), " as input #", i, "(zero-based)",
|
"cannot compute ", op->Name(), " as input #", i, "(zero-based)",
|
||||||
@ -503,9 +507,13 @@ Status EagerLocalExecute(EagerOperation* op,
|
|||||||
for (int i = 0; i < op->Inputs().size(); i++) {
|
for (int i = 0; i < op->Inputs().size(); i++) {
|
||||||
TensorHandle* input = op->Inputs()[i];
|
TensorHandle* input = op->Inputs()[i];
|
||||||
if (input->IsRemote()) {
|
if (input->IsRemote()) {
|
||||||
|
TensorHandle* handle = nullptr;
|
||||||
TF_RETURN_IF_ERROR(EagerCopyToDevice(
|
TF_RETURN_IF_ERROR(EagerCopyToDevice(
|
||||||
input, ctx, device == nullptr ? "" : device->name().c_str(),
|
input, ctx, device == nullptr ? "" : device->name().c_str(),
|
||||||
&(*op->MutableInputs())[i]));
|
&handle));
|
||||||
|
op->UpdateInput(i, handle);
|
||||||
|
// Unref handle since it has a ref as an input now
|
||||||
|
handle->Unref();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
@ -763,9 +771,13 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
|||||||
// Always copy to the remote CPU so that the actual device can be
|
// Always copy to the remote CPU so that the actual device can be
|
||||||
// correctly determined after the kernel is selected/instantiated, since
|
// correctly determined after the kernel is selected/instantiated, since
|
||||||
// the op might have its inputs on host memory.
|
// the op might have its inputs on host memory.
|
||||||
|
TensorHandle* handle = nullptr;
|
||||||
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
|
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
|
||||||
op, op->Device()->name(), i, remote_cpu_device,
|
op, op->Device()->name(), i, remote_cpu_device,
|
||||||
/* run_metadata= */ nullptr, &(*op->MutableInputs())[i]));
|
/* run_metadata= */ nullptr, &handle));
|
||||||
|
op->UpdateInput(i, handle);
|
||||||
|
// Unref handle since it has a ref as an input now
|
||||||
|
handle->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::TensorHandle* input = op->Inputs()[i];
|
tensorflow::TensorHandle* input = op->Inputs()[i];
|
||||||
|
@ -545,6 +545,7 @@ cuda_py_test(
|
|||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["memory_test.py"],
|
srcs = ["memory_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
|
":remote",
|
||||||
"//tensorflow/python/eager:backprop",
|
"//tensorflow/python/eager:backprop",
|
||||||
"//tensorflow/python/keras",
|
"//tensorflow/python/keras",
|
||||||
"//tensorflow/python/eager:test",
|
"//tensorflow/python/eager:test",
|
||||||
@ -553,6 +554,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
|
shard_count = 4,
|
||||||
tags = [
|
tags = [
|
||||||
"optonly", # The test is too slow in non-opt mode
|
"optonly", # The test is too slow in non-opt mode
|
||||||
],
|
],
|
||||||
|
@ -24,6 +24,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import six
|
import six
|
||||||
|
|
||||||
@ -31,11 +32,14 @@ from tensorflow.python import keras
|
|||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.eager import remote
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops.variables import Variable
|
from tensorflow.python.ops.variables import Variable
|
||||||
|
from tensorflow.python.training import server_lib
|
||||||
|
|
||||||
# memory_profiler might not be available in the OSS version of TensorFlow.
|
# memory_profiler might not be available in the OSS version of TensorFlow.
|
||||||
try:
|
try:
|
||||||
@ -55,12 +59,7 @@ class SingleLayerNet(keras.Model):
|
|||||||
return self.fc1(x)
|
return self.fc1(x)
|
||||||
|
|
||||||
|
|
||||||
class MemoryTest(test.TestCase):
|
def assert_no_leak(f, num_iters=100000, increase_threshold_absolute_mb=10):
|
||||||
|
|
||||||
def assertNotIncreasingMemory(self,
|
|
||||||
f,
|
|
||||||
num_iters=100000,
|
|
||||||
increase_threshold_absolute_mb=10):
|
|
||||||
"""Assert memory usage doesn't increase beyond given threshold for f."""
|
"""Assert memory usage doesn't increase beyond given threshold for f."""
|
||||||
|
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
@ -84,6 +83,9 @@ class MemoryTest(test.TestCase):
|
|||||||
"Maximum allowed increase: %f") % (initial, increase,
|
"Maximum allowed increase: %f") % (initial, increase,
|
||||||
increase_threshold_absolute_mb)
|
increase_threshold_absolute_mb)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTest(test.TestCase):
|
||||||
|
|
||||||
def testMemoryLeakAnonymousVariable(self):
|
def testMemoryLeakAnonymousVariable(self):
|
||||||
if memory_profiler is None:
|
if memory_profiler is None:
|
||||||
self.skipTest("memory_profiler required to run this test")
|
self.skipTest("memory_profiler required to run this test")
|
||||||
@ -92,7 +94,7 @@ class MemoryTest(test.TestCase):
|
|||||||
inputs = Variable(array_ops.zeros([32, 100], dtypes.float32))
|
inputs = Variable(array_ops.zeros([32, 100], dtypes.float32))
|
||||||
del inputs
|
del inputs
|
||||||
|
|
||||||
self.assertNotIncreasingMemory(f, num_iters=10000)
|
assert_no_leak(f, num_iters=10000)
|
||||||
|
|
||||||
def testMemoryLeakInSimpleModelForwardOnly(self):
|
def testMemoryLeakInSimpleModelForwardOnly(self):
|
||||||
if memory_profiler is None:
|
if memory_profiler is None:
|
||||||
@ -105,7 +107,7 @@ class MemoryTest(test.TestCase):
|
|||||||
with backprop.GradientTape():
|
with backprop.GradientTape():
|
||||||
net(inputs)
|
net(inputs)
|
||||||
|
|
||||||
self.assertNotIncreasingMemory(f)
|
assert_no_leak(f)
|
||||||
|
|
||||||
def testMemoryLeakInSimpleModelForwardAndBackward(self):
|
def testMemoryLeakInSimpleModelForwardAndBackward(self):
|
||||||
if memory_profiler is None:
|
if memory_profiler is None:
|
||||||
@ -122,7 +124,7 @@ class MemoryTest(test.TestCase):
|
|||||||
|
|
||||||
del tape
|
del tape
|
||||||
|
|
||||||
self.assertNotIncreasingMemory(f)
|
assert_no_leak(f)
|
||||||
|
|
||||||
def testMemoryLeakInFunction(self):
|
def testMemoryLeakInFunction(self):
|
||||||
if memory_profiler is None:
|
if memory_profiler is None:
|
||||||
@ -136,8 +138,39 @@ class MemoryTest(test.TestCase):
|
|||||||
|
|
||||||
graph(constant_op.constant(42))
|
graph(constant_op.constant(42))
|
||||||
|
|
||||||
self.assertNotIncreasingMemory(
|
assert_no_leak(f, num_iters=1000, increase_threshold_absolute_mb=30)
|
||||||
f, num_iters=1000, increase_threshold_absolute_mb=20)
|
|
||||||
|
|
||||||
|
class RemoteWorkerMemoryTest(test.TestCase):
|
||||||
|
|
||||||
|
def __init__(self, method):
|
||||||
|
super(RemoteWorkerMemoryTest, self).__init__(method)
|
||||||
|
|
||||||
|
# used for remote worker tests
|
||||||
|
os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
|
||||||
|
self._cached_server = server_lib.Server.create_local_server()
|
||||||
|
self._cached_server_target = self._cached_server.target[len("grpc://"):]
|
||||||
|
|
||||||
|
def testMemoryLeakInLocalCopy(self):
|
||||||
|
if memory_profiler is None:
|
||||||
|
self.skipTest("memory_profiler required to run this test")
|
||||||
|
|
||||||
|
remote.connect_to_remote_host(self._cached_server_target)
|
||||||
|
|
||||||
|
# Run a function locally with the input on a remote worker and ensure we
|
||||||
|
# do not leak a reference to the remote tensor.
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def local_func(i):
|
||||||
|
return i
|
||||||
|
|
||||||
|
def func():
|
||||||
|
with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
|
||||||
|
x = array_ops.zeros([1000, 1000], dtypes.int32)
|
||||||
|
|
||||||
|
local_func(x)
|
||||||
|
|
||||||
|
assert_no_leak(func, num_iters=100, increase_threshold_absolute_mb=50)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user