Improve handling of eager tensors on virtual CPUs
PiperOrigin-RevId: 257009081
parent 29c522fd3e
commit b9f8ebd7dd
@@ -1197,9 +1197,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* dstd,
   } else {
     tensorflow::Tensor tensor;
     TF_RETURN_IF_ERROR(h->CopyToDevice(ctx, dstd, &tensor));
-    const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
-    return TensorHandle::CreateLocalHandle(tensor, dst_cpu ? nullptr : dstd,
-                                           ctx, result);
+    return TensorHandle::CreateLocalHandle(tensor, dstd, ctx, result);
     return Status::OK();
   }
 }
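Note on the hunk above: CreateLocalHandle previously received nullptr instead of the destination device whenever the destination was a CPU, which collapses every virtual CPU onto the default host CPU; passing dstd through keeps the concrete device on the new handle. A minimal sketch of the visible effect from Python, assuming the tf.config.experimental virtual-device API of this release and eager execution enabled (all names here are illustrative and not part of the diff):

import tensorflow as tf

# Split the single physical CPU into two virtual CPUs; this must run before
# the runtime initializes its devices.
cpus = tf.config.experimental.list_physical_devices('CPU')
tf.config.experimental.set_virtual_device_configuration(
    cpus[0],
    [tf.config.experimental.VirtualDeviceConfiguration(),
     tf.config.experimental.VirtualDeviceConfiguration()])

with tf.device('cpu:1'):
  x = tf.constant(1.0)
# With the fix the eager handle keeps the device it was actually placed on.
print(x.device)  # .../device:CPU:1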
@@ -381,9 +381,10 @@ void TensorHandle::Poison(Status status) {
 Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
                                   tensorflow::Tensor* output) {
   tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
-  bool is_same_device = (srcd == dstd) || (srcd->name() == dstd->name());
   const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
   const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
+  bool is_same_device =
+      (srcd == dstd) || (srcd->name() == dstd->name()) || (dst_cpu && src_cpu);
 
   const tensorflow::Tensor* src = nullptr;
   TF_RETURN_IF_ERROR(Tensor(&src));
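Note on the hunk above: is_same_device now also holds when both source and destination are CPU devices, so a transfer between two virtual CPUs, which share the same host memory, takes the same-device path instead of a cross-device copy. Continuing the sketch above (illustrative only; the new testMultiCpuPlacement below exercises the same behaviour):

with tf.device('cpu:0'):
  y = tf.identity(x)  # CPU:1 -> CPU:0, backed by the same host memory
print(x.device, y.device)  # .../device:CPU:1 .../device:CPU:0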
@@ -31,6 +31,7 @@ from tensorflow.python.eager import core
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import execute as execute_lib
 from tensorflow.python.eager import test
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -62,8 +63,21 @@ def current_device():
   return constant_op.constant(1.).device
 
 
+def configure_virtual_cpus():
+  cpus = config.list_physical_devices('CPU')
+  # Set 2 virtual CPUs
+  config.set_virtual_device_configuration(cpus[0], [
+      context.VirtualDeviceConfiguration(),
+      context.VirtualDeviceConfiguration()
+  ])
+
+
 class TFETest(test_util.TensorFlowTestCase):
 
+  def setUp(self):
+    super(TFETest, self).setUp()
+    configure_virtual_cpus()
+
   def testContext(self):
     ctx = context.Context()
     self.assertTrue(ctx.executing_eagerly())
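The configure_virtual_cpus helper above uses TensorFlow-internal modules (the framework config module and context.VirtualDeviceConfiguration) because it lives inside the test suite; presumably every setUp can call it because re-requesting an unchanged configuration is accepted after the devices are initialized. Outside the source tree, a rough equivalent against the public experimental API might look like this hypothetical helper:

import tensorflow as tf

def configure_two_virtual_cpus():
  # Illustrative stand-in for the test helper, using tf.config.experimental.
  cpus = tf.config.experimental.list_physical_devices('CPU')
  tf.config.experimental.set_virtual_device_configuration(
      cpus[0],
      [tf.config.experimental.VirtualDeviceConfiguration(),
       tf.config.experimental.VirtualDeviceConfiguration()])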
@@ -130,6 +144,13 @@ class TFETest(test_util.TensorFlowTestCase):
                      cpu_stats.device)
     self.assertGreaterEqual(len(cpu_stats.node_stats), 1)
 
+  def testMultiCpuPlacement(self):
+    with ops.device('cpu:1'):
+      x = constant_op.constant(1.0)
+    y = array_ops.identity(x)
+    self.assertEqual(x.device, '/job:localhost/replica:0/task:0/device:CPU:1')
+    self.assertEqual(y.device, '/job:localhost/replica:0/task:0/device:CPU:0')
+
   @test_util.run_gpu_only
   def testShouldCopy(self):
     with ops.device('gpu:0'):
@@ -758,6 +779,10 @@ class SendRecvTest(test_util.TensorFlowTestCase):
         'recv_device', device_name,
         'client_terminated', False))[0]
 
+  def setUp(self):
+    super(SendRecvTest, self).setUp()
+    configure_virtual_cpus()
+
   def testBasic(self):
     t0 = constant_op.constant(1.0)
     t1 = constant_op.constant(2.0)
@@ -789,6 +814,10 @@ class SendRecvTest(test_util.TensorFlowTestCase):
 
 class EagerTensorCacheTest(test_util.TensorFlowTestCase):
 
+  def setUp(self):
+    super(EagerTensorCacheTest, self).setUp()
+    configure_virtual_cpus()
+
   def testCacheSkipsTensorsTooLarge(self):
     cache = context._EagerTensorCache(max_items=100, max_tensor_size=3)
     cache.put('1', array_ops.zeros((2, 2)))