Improve handling of eager tensors on virtual CPUs
PiperOrigin-RevId: 257009081
This commit is contained in:
parent
29c522fd3e
commit
b9f8ebd7dd
@ -1197,9 +1197,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* dstd,
|
|||||||
} else {
|
} else {
|
||||||
tensorflow::Tensor tensor;
|
tensorflow::Tensor tensor;
|
||||||
TF_RETURN_IF_ERROR(h->CopyToDevice(ctx, dstd, &tensor));
|
TF_RETURN_IF_ERROR(h->CopyToDevice(ctx, dstd, &tensor));
|
||||||
const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
|
return TensorHandle::CreateLocalHandle(tensor, dstd, ctx, result);
|
||||||
return TensorHandle::CreateLocalHandle(tensor, dst_cpu ? nullptr : dstd,
|
|
||||||
ctx, result);
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -381,9 +381,10 @@ void TensorHandle::Poison(Status status) {
|
|||||||
Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
|
Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
|
||||||
tensorflow::Tensor* output) {
|
tensorflow::Tensor* output) {
|
||||||
tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
|
tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
|
||||||
bool is_same_device = (srcd == dstd) || (srcd->name() == dstd->name());
|
|
||||||
const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
|
const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
|
||||||
const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
|
const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
|
||||||
|
bool is_same_device =
|
||||||
|
(srcd == dstd) || (srcd->name() == dstd->name()) || (dst_cpu && src_cpu);
|
||||||
|
|
||||||
const tensorflow::Tensor* src = nullptr;
|
const tensorflow::Tensor* src = nullptr;
|
||||||
TF_RETURN_IF_ERROR(Tensor(&src));
|
TF_RETURN_IF_ERROR(Tensor(&src));
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.python.eager import core
|
|||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import execute as execute_lib
|
from tensorflow.python.eager import execute as execute_lib
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
@ -62,8 +63,21 @@ def current_device():
|
|||||||
return constant_op.constant(1.).device
|
return constant_op.constant(1.).device
|
||||||
|
|
||||||
|
|
||||||
|
def configure_virtual_cpus():
|
||||||
|
cpus = config.list_physical_devices('CPU')
|
||||||
|
# Set 2 virtual CPUs
|
||||||
|
config.set_virtual_device_configuration(cpus[0], [
|
||||||
|
context.VirtualDeviceConfiguration(),
|
||||||
|
context.VirtualDeviceConfiguration()
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
class TFETest(test_util.TensorFlowTestCase):
|
class TFETest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(TFETest, self).setUp()
|
||||||
|
configure_virtual_cpus()
|
||||||
|
|
||||||
def testContext(self):
|
def testContext(self):
|
||||||
ctx = context.Context()
|
ctx = context.Context()
|
||||||
self.assertTrue(ctx.executing_eagerly())
|
self.assertTrue(ctx.executing_eagerly())
|
||||||
@ -130,6 +144,13 @@ class TFETest(test_util.TensorFlowTestCase):
|
|||||||
cpu_stats.device)
|
cpu_stats.device)
|
||||||
self.assertGreaterEqual(len(cpu_stats.node_stats), 1)
|
self.assertGreaterEqual(len(cpu_stats.node_stats), 1)
|
||||||
|
|
||||||
|
def testMultiCpuPlacement(self):
|
||||||
|
with ops.device('cpu:1'):
|
||||||
|
x = constant_op.constant(1.0)
|
||||||
|
y = array_ops.identity(x)
|
||||||
|
self.assertEqual(x.device, '/job:localhost/replica:0/task:0/device:CPU:1')
|
||||||
|
self.assertEqual(y.device, '/job:localhost/replica:0/task:0/device:CPU:0')
|
||||||
|
|
||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
def testShouldCopy(self):
|
def testShouldCopy(self):
|
||||||
with ops.device('gpu:0'):
|
with ops.device('gpu:0'):
|
||||||
@ -758,6 +779,10 @@ class SendRecvTest(test_util.TensorFlowTestCase):
|
|||||||
'recv_device', device_name,
|
'recv_device', device_name,
|
||||||
'client_terminated', False))[0]
|
'client_terminated', False))[0]
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(SendRecvTest, self).setUp()
|
||||||
|
configure_virtual_cpus()
|
||||||
|
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
t0 = constant_op.constant(1.0)
|
t0 = constant_op.constant(1.0)
|
||||||
t1 = constant_op.constant(2.0)
|
t1 = constant_op.constant(2.0)
|
||||||
@ -789,6 +814,10 @@ class SendRecvTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
class EagerTensorCacheTest(test_util.TensorFlowTestCase):
|
class EagerTensorCacheTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(EagerTensorCacheTest, self).setUp()
|
||||||
|
configure_virtual_cpus()
|
||||||
|
|
||||||
def testCacheSkipsTensorsTooLarge(self):
|
def testCacheSkipsTensorsTooLarge(self):
|
||||||
cache = context._EagerTensorCache(max_items=100, max_tensor_size=3)
|
cache = context._EagerTensorCache(max_items=100, max_tensor_size=3)
|
||||||
cache.put('1', array_ops.zeros((2, 2)))
|
cache.put('1', array_ops.zeros((2, 2)))
|
||||||
|
Loading…
Reference in New Issue
Block a user