Fix device placement logic in ConvertToEagerTensor.

The fix consists in always creating a host tensor: the inputs to
ConvertToEagerTensor are host Python objects, so it makes sense that
the created tensor should be a host tensor too. The user can control
GPU copies by using tf.identity.

PiperOrigin-RevId: 297174740
Change-Id: I01f2aa9be3eb29fd49c7d81823e044db292b2d7c
This commit is contained in:
A. Unique TensorFlower 2020-02-25 12:19:21 -08:00 committed by TensorFlower Gardener
parent 4a010cc04c
commit 97cdd4d16a
13 changed files with 107 additions and 74 deletions

View File

@ -695,7 +695,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
wholly_compiled_f = def_function.function(f) wholly_compiled_f = def_function.function(f)
op_by_op_f = def_function.function(f, experimental_compile=False) op_by_op_f = def_function.function(f, experimental_compile=False)
x = constant_op.constant([0.0, 2.0], name='data') x = array_ops.identity([0.0, 2.0], name='data')
# When function is wholly compiled, all outputs will be on the # When function is wholly compiled, all outputs will be on the
# device on which it is run. # device on which it is run.

View File

@ -45,7 +45,7 @@ StaticDeviceMgr::StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices)
} }
const auto& t = d->device_type(); const auto& t = d->device_type();
device_type_counts_[t]++; device_type_counts_[t]++;
if (cpu_device_ == nullptr && t == "CPU") { if (cpu_device_ == nullptr && t == "CPU" && d->parsed_name().id == 0) {
cpu_device_ = d.get(); cpu_device_ = d.get();
} }
} }

View File

@ -194,7 +194,8 @@ Device* DynamicDeviceMgr::HostCPU() const {
} }
cpu_device_ = nullptr; cpu_device_ = nullptr;
for (const auto& pair : dynamic_devices_) { for (const auto& pair : dynamic_devices_) {
if (pair.first->device_type() == DEVICE_CPU) { if (pair.first->device_type() == DEVICE_CPU &&
pair.first->parsed_name().id == 0) {
cpu_device_ = pair.first; cpu_device_ = pair.first;
break; break;
} }

View File

@ -167,6 +167,16 @@ std::vector<string> DevicesToString(const PrioritizedDeviceVector& devices) {
return v; return v;
} }
std::vector<string> DeviceTypesToString(
const PrioritizedDeviceTypeVector& types) {
std::vector<string> v;
v.reserve(types.size());
for (const auto& p : types) {
v.push_back(p.first.type_string());
}
return v;
}
// Selects the "best" device that both exists and is supported. // Selects the "best" device that both exists and is supported.
// //
// The `existing` argument specifies the available devices in the system, in // The `existing` argument specifies the available devices in the system, in
@ -232,13 +242,17 @@ Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
return errors::InvalidArgument( return errors::InvalidArgument(
"Could not satisfy device specification '", preferred, "Could not satisfy device specification '", preferred,
"'. enable_soft_placement=", AllowSoftPlacement(), "'. enable_soft_placement=", AllowSoftPlacement(),
". All available devices [", ". Supported device types [",
absl::StrJoin(DeviceTypesToString(supported), ", "),
"]. All available devices [",
absl::StrJoin(DevicesToString(existing), ", "), "]."); absl::StrJoin(DevicesToString(existing), ", "), "].");
} }
return errors::InvalidArgument( return errors::InvalidArgument(
"No supported device found in available devices [", "No supported device found in available devices [",
absl::StrJoin(DevicesToString(existing), ", "), absl::StrJoin(DevicesToString(existing), ", "),
"]. enable_soft_placement=", AllowSoftPlacement(), "."); "]. enable_soft_placement=", AllowSoftPlacement(),
". Supported devices types [",
absl::StrJoin(DeviceTypesToString(supported), ", "), "].");
} }
void EagerContext::ResetClusterFLR( void EagerContext::ResetClusterFLR(

View File

@ -1356,14 +1356,14 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
def forward(x, w, b): def forward(x, w, b):
return x * w + b return x * w + b
x = constant_op.constant([1.0], name="x_useless") x = array_ops.identity([1.0], name="x_useless")
concrete_forward = forward.get_concrete_function(x, w._primary, b._primary) concrete_forward = forward.get_concrete_function(x, w._primary, b._primary)
with distribution.scope(): with distribution.scope():
def replica_fn(): def replica_fn():
with backprop.GradientTape() as t: with backprop.GradientTape() as t:
x = constant_op.constant([1.0], name="x") x = array_ops.identity([1.0], name="x")
loss = concrete_forward(x, w._get(), b._get()) - [1.0] loss = concrete_forward(x, w._get(), b._get()) - [1.0]
return t.gradient(loss, [w, b]) return t.gradient(loss, [w, b])

View File

@ -688,9 +688,9 @@ class RemoteSingleWorkerMirroredStrategyBase(DistributionTestBase):
def _testDeviceScope(self, distribution): def _testDeviceScope(self, distribution):
with distribution.scope(): with distribution.scope():
a = constant_op.constant(1.) a = array_ops.identity(1.)
with ops.device("/cpu:0"): with ops.device("/cpu:0"):
b = constant_op.constant(1.) b = array_ops.identity(1.)
if context.executing_eagerly(): if context.executing_eagerly():
device = "/job:worker/replica:0/task:0/device:CPU:0" device = "/job:worker/replica:0/task:0/device:CPU:0"
else: else:

View File

@ -395,7 +395,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testMultiCpuPlacement(self): def testMultiCpuPlacement(self):
with ops.device('cpu:1'): with ops.device('cpu:1'):
x = constant_op.constant(1.0) x = array_ops.identity(1.0)
with ops.device('cpu:0'): with ops.device('cpu:0'):
y = array_ops.identity(x) y = array_ops.identity(x)
self.assertEqual(x.device, '/job:localhost/replica:0/task:0/device:CPU:1') self.assertEqual(x.device, '/job:localhost/replica:0/task:0/device:CPU:1')
@ -1084,7 +1084,7 @@ class SendRecvTest(test_util.TensorFlowTestCase):
def testLocalCrossDevice(self): def testLocalCrossDevice(self):
gpu_device_name = '/job:localhost/replica:0/task:0/device:GPU:0' gpu_device_name = '/job:localhost/replica:0/task:0/device:GPU:0'
with ops.device('GPU:0'): with ops.device('GPU:0'):
t0 = constant_op.constant(1.0) t0 = array_ops.identity(1.0)
self._send(t0, 't0', self.cpu_device) self._send(t0, 't0', self.cpu_device)
with ops.device('cpu:0'): with ops.device('cpu:0'):
self.assertAllEqual( self.assertAllEqual(
@ -1115,4 +1115,5 @@ class EagerTensorCacheTest(test_util.TensorFlowTestCase):
if __name__ == '__main__': if __name__ == '__main__':
context.set_log_device_placement(True)
test.main() test.main()

View File

@ -18,24 +18,26 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote from tensorflow.python.eager import remote
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import config from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
class SoftDevicePlacementTest(test.TestCase): class SoftDevicePlacementTest(test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
super(SoftDevicePlacementTest, self).setUp() super(SoftDevicePlacementTest, self).setUp()
context._context = None context._reset_context()
ops.enable_eager_execution_internal()
config.set_soft_device_placement(enabled=True) config.set_soft_device_placement(enabled=True)
context.context().log_device_placement = True context.context().log_device_placement = True
@ -90,13 +92,21 @@ class SoftDevicePlacementTest(test.TestCase):
# We don't support nested device placement right now. # We don't support nested device placement right now.
self.assertIn('GPU:0', c.device) self.assertIn('GPU:0', c.device)
@parameterized.named_parameters(('float', 1.0, None),
('int32', [1], dtypes.int32),
('string', ['a'], None))
def testSoftPlacedCPUConstant(self, value, dtype):
with ops.device('GPU:0'):
a = constant_op.constant(value, dtype=dtype)
self.assertIn('CPU:0', a.device)
self.assertIn('CPU:0', a.backing_device)
class HardDevicePlacementTest(test.TestCase):
class HardDevicePlacementTest(test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
super(HardDevicePlacementTest, self).setUp() super(HardDevicePlacementTest, self).setUp()
context._context = None context._reset_context()
ops.enable_eager_execution_internal()
config.set_soft_device_placement(enabled=False) config.set_soft_device_placement(enabled=False)
context.context().log_device_placement = True context.context().log_device_placement = True
self.assertEqual(config.get_soft_device_placement(), False) self.assertEqual(config.get_soft_device_placement(), False)
@ -114,13 +124,27 @@ class HardDevicePlacementTest(test.TestCase):
self.assertIn('GPU:0', y.device) self.assertIn('GPU:0', y.device)
self.assertIn('GPU:0', y.backing_device) self.assertIn('GPU:0', y.backing_device)
@parameterized.named_parameters(('float_cpu0', 'CPU:0', 1.0, None),
('int32_cpu0', 'CPU:0', [1], dtypes.int32),
('string_cpu0', 'CPU:0', ['a'], None),
('float_gpu0', 'GPU:0', 1.0, None),
('int32_gpu0', 'GPU:0', [1], dtypes.int32),
('string_gpu0', 'GPU:0', ['a'], None),
('float_gpu99', 'GPU:99', 1.0, None),
('int32_gpu99', 'GPU:99', [1], dtypes.int32),
('string_gpu99', 'GPU:99', ['a'], None))
def testHardPlacedCPUConstant(self, device, value, dtype):
with ops.device(device):
a = constant_op.constant(value, dtype=dtype)
self.assertIn('CPU:0', a.device)
self.assertIn('CPU:0', a.backing_device)
class ClusterPlacementTest(test.TestCase): class ClusterPlacementTest(test.TestCase):
def setUp(self): def setUp(self):
super(ClusterPlacementTest, self).setUp() super(ClusterPlacementTest, self).setUp()
context._context = None context._reset_context()
ops.enable_eager_execution_internal()
config.set_soft_device_placement(enabled=True) config.set_soft_device_placement(enabled=True)
context.context().log_device_placement = True context.context().log_device_placement = True
workers, _ = test_util.create_local_cluster(2, 0) workers, _ = test_util.create_local_cluster(2, 0)

View File

@ -1570,10 +1570,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
def testColocateWithRespected(self): def testColocateWithRespected(self):
# TODO(b/113291792): Use multiple CPUs instead of a GPU. # TODO(b/113291792): Use multiple CPUs instead of a GPU.
with ops.device('cpu:0'): with ops.device('cpu:0'):
x = constant_op.constant(1.0) x = array_ops.identity(1.0)
with ops.device('gpu:0'): with ops.device('gpu:0'):
y = constant_op.constant(1.0) y = array_ops.identity(1.0)
@def_function.function @def_function.function
def foo(): def foo():
@ -3239,9 +3239,9 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
return b, a return b, a
with ops.device('/device:CPU:0'): with ops.device('/device:CPU:0'):
a = constant_op.constant(3.0) a = array_ops.identity(3.0)
with ops.device('/device:GPU:0'): with ops.device('/device:GPU:0'):
b = constant_op.constant(5.0) b = array_ops.identity(5.0)
m1, m2 = func(a, b) m1, m2 = func(a, b)
self.assertAllEqual(m1.numpy(), 5.0) self.assertAllEqual(m1.numpy(), 5.0)
@ -3306,9 +3306,9 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
devices = ['/device:CPU:0', '/device:GPU:0'] devices = ['/device:CPU:0', '/device:GPU:0']
for dev1, dev2 in itertools.product(devices, devices): for dev1, dev2 in itertools.product(devices, devices):
with ops.device(dev1): with ops.device(dev1):
a = constant_op.constant(1.0) a = array_ops.identity(1.0)
with ops.device(dev2): with ops.device(dev2):
b = constant_op.constant(10.0) b = array_ops.identity(10.0)
ra, rb = func(a, b) ra, rb = func(a, b)
self.assertEqual(ra.numpy(), 2.0) self.assertEqual(ra.numpy(), 2.0)
@ -3469,13 +3469,13 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
with ops.device('/device:CPU:0'): with ops.device('/device:CPU:0'):
rc0 = resource_variable_ops.ResourceVariable(2.0) rc0 = resource_variable_ops.ResourceVariable(2.0)
rc1 = resource_variable_ops.ResourceVariable(3.0) rc1 = resource_variable_ops.ResourceVariable(3.0)
cc0 = constant_op.constant(5.0) cc0 = array_ops.identity(5.0)
cc1 = constant_op.constant(7.0) cc1 = array_ops.identity(7.0)
with ops.device('/device:GPU:0'): with ops.device('/device:GPU:0'):
rg0 = resource_variable_ops.ResourceVariable(11.0) rg0 = resource_variable_ops.ResourceVariable(11.0)
rg1 = resource_variable_ops.ResourceVariable(13.0) rg1 = resource_variable_ops.ResourceVariable(13.0)
cg0 = constant_op.constant(17.0) cg0 = array_ops.identity(17.0)
cg1 = constant_op.constant(19.0) cg1 = array_ops.identity(19.0)
# Make sure tensors are on expected devices. # Make sure tensors are on expected devices.
for tensor in [cc0, cc1]: for tensor in [cc0, cc1]:

View File

@ -278,39 +278,13 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
} }
} }
// Almost all TensorFlow kernels for GPU devices keep int32 tensors in host // We always generate CPU:0 tensors, but we may need to change the device
// memory. We approximate the same behavior for eager execution - keeping // slightly, as for example from /job:localhost/... to /job:worker/...
// int32 tensors in host memory.
// //
// We do so to preclude the need for callers into such kernels from having to // Note that this is a shallow copy and will share the underlying buffer,
// explicitly place the int32 tensors in host memory. For example, without // because we are copying to the same device.
// this, one needed:
//
// with tf.device('/gpu:0'):
// ...// code here
// with tf.device('/cpu:0'):
// shape = tf.constant(...)
// y = tf.random_uniform(shape)
//
// Without the CPU device block, tfe.ops.random_uniform would fail since the
// kernel expects the shape in host memory.
//
// With this support, we simplify the code:
//
// with tf.device('/gpu:0'):
// y = tf.random_uniform(...)
//
// The approximation is not exact there are GPU kernels which do not require
// host memory for int32 tensors. This will lead to a discrepancy between
// eager and graph execution.
//
// To support remote execution copy int32 tensors to another CPU device.
// TODO(ashankar): Fix this.
if (device_name != nullptr && if (device_name != nullptr &&
(TFE_TensorHandleDataType(handle.get()) != TF_INT32 || strstr(device_name, "/device:CPU:0") != nullptr) {
strstr(device_name, "/device:CPU:0") != nullptr)) {
// Note that this is a shallow copy and will share the underlying buffer
// if copying to the same device.
handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx, handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx,
device_name, status.get())); device_name, status.get()));
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) { if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) {
@ -318,6 +292,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
} }
} }
// We always enable implicit mirroring for constants. Without this, code
// written previously under the assumption that
//
// with tf.device('GPU:0'): x = tf.constant(1.0)
//
// will be placed in the GPU will suffer a non-trivial performance regression
// (measured at ~20% for certain benchmarks).
handle->handle->EnableImplicitMirroring();
return handle.release(); return handle.release();
} }

View File

@ -281,9 +281,9 @@ class TFETensorTest(test_util.TensorFlowTestCase):
@test_util.run_gpu_only @test_util.run_gpu_only
def testStringTensorOnGPU(self): def testStringTensorOnGPU(self):
with ops.device("/device:GPU:0"): with ops.device("/device:GPU:0"):
with self.assertRaisesRegexp( t = _create_tensor("test string")
RuntimeError, "Can't copy Tensor with type string to device"): self.assertIn("CPU", t.device)
_create_tensor("test string") self.assertIn("CPU", t.backing_device)
def testInvalidUTF8ProducesReasonableError(self): def testInvalidUTF8ProducesReasonableError(self):
if sys.version_info[0] < 3: if sys.version_info[0] < 3:

View File

@ -33,6 +33,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.util import compat from tensorflow.python.util import compat
@ -380,10 +381,10 @@ class DeviceTest(test.TestCase):
with ops.device('/device:CPU:1'): with ops.device('/device:CPU:1'):
b = constant_op.constant(1.0) b = constant_op.constant(1.0)
self.evaluate(b) self.evaluate(b)
with self.assertRaisesRegexp(RuntimeError, 'unknown device'): with ops.device('/device:CPU:2'):
with ops.device('/device:CPU:2'): c = constant_op.constant(1.0)
c = constant_op.constant(1.0) self.evaluate(c)
self.evaluate(c) self.assertIn('CPU:0', c.device)
# Ensure we can place ops on each of the device names # Ensure we can place ops on each of the device names
for vcpu in vcpus: for vcpu in vcpus:
@ -408,6 +409,7 @@ class DeviceTest(test.TestCase):
@test_util.run_gpu_only @test_util.run_gpu_only
@reset_eager @reset_eager
def testGpuNone(self): def testGpuNone(self):
config.set_soft_device_placement(False)
gpus = config.list_physical_devices('GPU') gpus = config.list_physical_devices('GPU')
self.assertGreater(len(gpus), 0) self.assertGreater(len(gpus), 0)
@ -427,14 +429,16 @@ class DeviceTest(test.TestCase):
self.assertEqual(len(config.get_visible_devices('GPU')), 0) self.assertEqual(len(config.get_visible_devices('GPU')), 0)
self.assertEqual(len(config.list_logical_devices('XLA_GPU')), 0) self.assertEqual(len(config.list_logical_devices('XLA_GPU')), 0)
with self.assertRaisesRegexp(RuntimeError, 'unknown device'): with self.assertRaisesRegexp(errors.InvalidArgumentError,
'Could not satisfy'):
with ops.device('/device:GPU:0'): with ops.device('/device:GPU:0'):
a = constant_op.constant(1.0) a = array_ops.identity(1.0)
self.evaluate(a) self.evaluate(a)
with self.assertRaisesRegexp(RuntimeError, 'unknown device'): with self.assertRaisesRegexp(errors.InvalidArgumentError,
'Could not satisfy'):
with ops.device('/device:XLA_GPU:0'): with ops.device('/device:XLA_GPU:0'):
a = constant_op.constant(1.0) a = array_ops.identity(1.0)
self.evaluate(a) self.evaluate(a)
# Modifying the visible devices is not supported # Modifying the visible devices is not supported
@ -465,6 +469,7 @@ class DeviceTest(test.TestCase):
@test_util.run_gpu_only @test_util.run_gpu_only
@reset_eager @reset_eager
def testVirtualGpu(self): def testVirtualGpu(self):
config.set_soft_device_placement(False)
gpus = config.list_physical_devices('GPU') gpus = config.list_physical_devices('GPU')
self.assertNotEqual(len(gpus), 0) self.assertNotEqual(len(gpus), 0)
@ -479,12 +484,13 @@ class DeviceTest(test.TestCase):
self.assertTrue(len(logical_gpus), len(gpus) + 1) self.assertTrue(len(logical_gpus), len(gpus) + 1)
for i in range(0, len(logical_gpus)): for i in range(0, len(logical_gpus)):
with ops.device('/device:GPU:' + str(i)): with ops.device('/device:GPU:' + str(i)):
a = constant_op.constant(1.0) a = array_ops.identity(1.0)
self.evaluate(a) self.evaluate(a)
with self.assertRaisesRegexp(RuntimeError, 'unknown device'): with self.assertRaisesRegexp(errors.InvalidArgumentError,
'Could not satisfy'):
with ops.device('/device:GPU:' + str(len(logical_gpus))): with ops.device('/device:GPU:' + str(len(logical_gpus))):
a = constant_op.constant(1.0) a = array_ops.identity(1.0)
self.evaluate(a) self.evaluate(a)
# Modifying the GPU configuration is not supported # Modifying the GPU configuration is not supported

View File

@ -224,6 +224,10 @@ def constant(value, dtype=None, shape=None, name="Const"):
... ...
NotImplementedError: ... NotImplementedError: ...
`tf.constant` will _always_ create CPU (host) tensors. In order to create
tensors on other devices, use `tf.identity`. (If the `value` is an eager
Tensor, however, the tensor will be returned unmodified as mentioned above.)
Related Ops: Related Ops:
* `tf.convert_to_tensor` is similar but: * `tf.convert_to_tensor` is similar but: