Improve handling of eager tensors on virtual CPUs

PiperOrigin-RevId: 257009081
This commit is contained in:
Gaurav Jain 2019-07-08 10:36:51 -07:00 committed by TensorFlower Gardener
parent 29c522fd3e
commit b9f8ebd7dd
3 changed files with 32 additions and 4 deletions

View File

@@ -1197,9 +1197,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* dstd,
   } else {
     tensorflow::Tensor tensor;
     TF_RETURN_IF_ERROR(h->CopyToDevice(ctx, dstd, &tensor));
-    const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
-    return TensorHandle::CreateLocalHandle(tensor, dst_cpu ? nullptr : dstd,
-                                           ctx, result);
+    return TensorHandle::CreateLocalHandle(tensor, dstd, ctx, result);
     return Status::OK();
   }
 }

View File

@@ -381,9 +381,10 @@ void TensorHandle::Poison(Status status) {
 Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
                                   tensorflow::Tensor* output) {
   tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
-  bool is_same_device = (srcd == dstd) || (srcd->name() == dstd->name());
   const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
   const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
+  bool is_same_device =
+      (srcd == dstd) || (srcd->name() == dstd->name()) || (dst_cpu && src_cpu);
   const tensorflow::Tensor* src = nullptr;
   TF_RETURN_IF_ERROR(Tensor(&src));

View File

@@ -31,6 +31,7 @@ from tensorflow.python.eager import core
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import execute as execute_lib
 from tensorflow.python.eager import test
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -62,8 +63,21 @@ def current_device():
   return constant_op.constant(1.).device
 
 
+def configure_virtual_cpus():
+  cpus = config.list_physical_devices('CPU')
+  # Set 2 virtual CPUs
+  config.set_virtual_device_configuration(cpus[0], [
+      context.VirtualDeviceConfiguration(),
+      context.VirtualDeviceConfiguration()
+  ])
+
+
 class TFETest(test_util.TensorFlowTestCase):
 
+  def setUp(self):
+    super(TFETest, self).setUp()
+    configure_virtual_cpus()
+
   def testContext(self):
     ctx = context.Context()
     self.assertTrue(ctx.executing_eagerly())
@@ -130,6 +144,13 @@ class TFETest(test_util.TensorFlowTestCase):
                      cpu_stats.device)
     self.assertGreaterEqual(len(cpu_stats.node_stats), 1)
 
+  def testMultiCpuPlacement(self):
+    with ops.device('cpu:1'):
+      x = constant_op.constant(1.0)
+    y = array_ops.identity(x)
+    self.assertEqual(x.device, '/job:localhost/replica:0/task:0/device:CPU:1')
+    self.assertEqual(y.device, '/job:localhost/replica:0/task:0/device:CPU:0')
+
   @test_util.run_gpu_only
   def testShouldCopy(self):
     with ops.device('gpu:0'):
@@ -758,6 +779,10 @@ class SendRecvTest(test_util.TensorFlowTestCase):
                 'recv_device', device_name,
                 'client_terminated', False))[0]
 
+  def setUp(self):
+    super(SendRecvTest, self).setUp()
+    configure_virtual_cpus()
+
   def testBasic(self):
     t0 = constant_op.constant(1.0)
     t1 = constant_op.constant(2.0)
@@ -789,6 +814,10 @@
 class EagerTensorCacheTest(test_util.TensorFlowTestCase):
 
+  def setUp(self):
+    super(EagerTensorCacheTest, self).setUp()
+    configure_virtual_cpus()
+
   def testCacheSkipsTensorsTooLarge(self):
     cache = context._EagerTensorCache(max_items=100, max_tensor_size=3)
     cache.put('1', array_ops.zeros((2, 2)))