Fix device placement logic in ConvertToEagerTensor.
The fix consists in always creating a host tensor: the inputs to ConvertToEagerTensor are host Python objects, so it makes sense that the created tensor should be a host tensor too. The user can control GPU copies by using tf.identity. PiperOrigin-RevId: 297174740 Change-Id: I01f2aa9be3eb29fd49c7d81823e044db292b2d7c
This commit is contained in:
parent
4a010cc04c
commit
97cdd4d16a
@ -695,7 +695,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
|||||||
wholly_compiled_f = def_function.function(f)
|
wholly_compiled_f = def_function.function(f)
|
||||||
op_by_op_f = def_function.function(f, experimental_compile=False)
|
op_by_op_f = def_function.function(f, experimental_compile=False)
|
||||||
|
|
||||||
x = constant_op.constant([0.0, 2.0], name='data')
|
x = array_ops.identity([0.0, 2.0], name='data')
|
||||||
|
|
||||||
# When function is wholly compiled, all outputs will be on the
|
# When function is wholly compiled, all outputs will be on the
|
||||||
# device on which it is run.
|
# device on which it is run.
|
||||||
|
@ -45,7 +45,7 @@ StaticDeviceMgr::StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices)
|
|||||||
}
|
}
|
||||||
const auto& t = d->device_type();
|
const auto& t = d->device_type();
|
||||||
device_type_counts_[t]++;
|
device_type_counts_[t]++;
|
||||||
if (cpu_device_ == nullptr && t == "CPU") {
|
if (cpu_device_ == nullptr && t == "CPU" && d->parsed_name().id == 0) {
|
||||||
cpu_device_ = d.get();
|
cpu_device_ = d.get();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -194,7 +194,8 @@ Device* DynamicDeviceMgr::HostCPU() const {
|
|||||||
}
|
}
|
||||||
cpu_device_ = nullptr;
|
cpu_device_ = nullptr;
|
||||||
for (const auto& pair : dynamic_devices_) {
|
for (const auto& pair : dynamic_devices_) {
|
||||||
if (pair.first->device_type() == DEVICE_CPU) {
|
if (pair.first->device_type() == DEVICE_CPU &&
|
||||||
|
pair.first->parsed_name().id == 0) {
|
||||||
cpu_device_ = pair.first;
|
cpu_device_ = pair.first;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -167,6 +167,16 @@ std::vector<string> DevicesToString(const PrioritizedDeviceVector& devices) {
|
|||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<string> DeviceTypesToString(
|
||||||
|
const PrioritizedDeviceTypeVector& types) {
|
||||||
|
std::vector<string> v;
|
||||||
|
v.reserve(types.size());
|
||||||
|
for (const auto& p : types) {
|
||||||
|
v.push_back(p.first.type_string());
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
|
||||||
// Selects the "best" device that both exists and is supported.
|
// Selects the "best" device that both exists and is supported.
|
||||||
//
|
//
|
||||||
// The `existing` argument specifies the available devices in the system, in
|
// The `existing` argument specifies the available devices in the system, in
|
||||||
@ -232,13 +242,17 @@ Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
|
|||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Could not satisfy device specification '", preferred,
|
"Could not satisfy device specification '", preferred,
|
||||||
"'. enable_soft_placement=", AllowSoftPlacement(),
|
"'. enable_soft_placement=", AllowSoftPlacement(),
|
||||||
". All available devices [",
|
". Supported device types [",
|
||||||
|
absl::StrJoin(DeviceTypesToString(supported), ", "),
|
||||||
|
"]. All available devices [",
|
||||||
absl::StrJoin(DevicesToString(existing), ", "), "].");
|
absl::StrJoin(DevicesToString(existing), ", "), "].");
|
||||||
}
|
}
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"No supported device found in available devices [",
|
"No supported device found in available devices [",
|
||||||
absl::StrJoin(DevicesToString(existing), ", "),
|
absl::StrJoin(DevicesToString(existing), ", "),
|
||||||
"]. enable_soft_placement=", AllowSoftPlacement(), ".");
|
"]. enable_soft_placement=", AllowSoftPlacement(),
|
||||||
|
". Supported devices types [",
|
||||||
|
absl::StrJoin(DeviceTypesToString(supported), ", "), "].");
|
||||||
}
|
}
|
||||||
|
|
||||||
void EagerContext::ResetClusterFLR(
|
void EagerContext::ResetClusterFLR(
|
||||||
|
@ -1356,14 +1356,14 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
def forward(x, w, b):
|
def forward(x, w, b):
|
||||||
return x * w + b
|
return x * w + b
|
||||||
|
|
||||||
x = constant_op.constant([1.0], name="x_useless")
|
x = array_ops.identity([1.0], name="x_useless")
|
||||||
concrete_forward = forward.get_concrete_function(x, w._primary, b._primary)
|
concrete_forward = forward.get_concrete_function(x, w._primary, b._primary)
|
||||||
|
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
|
|
||||||
def replica_fn():
|
def replica_fn():
|
||||||
with backprop.GradientTape() as t:
|
with backprop.GradientTape() as t:
|
||||||
x = constant_op.constant([1.0], name="x")
|
x = array_ops.identity([1.0], name="x")
|
||||||
loss = concrete_forward(x, w._get(), b._get()) - [1.0]
|
loss = concrete_forward(x, w._get(), b._get()) - [1.0]
|
||||||
return t.gradient(loss, [w, b])
|
return t.gradient(loss, [w, b])
|
||||||
|
|
||||||
|
@ -688,9 +688,9 @@ class RemoteSingleWorkerMirroredStrategyBase(DistributionTestBase):
|
|||||||
|
|
||||||
def _testDeviceScope(self, distribution):
|
def _testDeviceScope(self, distribution):
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
a = constant_op.constant(1.)
|
a = array_ops.identity(1.)
|
||||||
with ops.device("/cpu:0"):
|
with ops.device("/cpu:0"):
|
||||||
b = constant_op.constant(1.)
|
b = array_ops.identity(1.)
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
device = "/job:worker/replica:0/task:0/device:CPU:0"
|
device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||||
else:
|
else:
|
||||||
|
@ -395,7 +395,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def testMultiCpuPlacement(self):
|
def testMultiCpuPlacement(self):
|
||||||
with ops.device('cpu:1'):
|
with ops.device('cpu:1'):
|
||||||
x = constant_op.constant(1.0)
|
x = array_ops.identity(1.0)
|
||||||
with ops.device('cpu:0'):
|
with ops.device('cpu:0'):
|
||||||
y = array_ops.identity(x)
|
y = array_ops.identity(x)
|
||||||
self.assertEqual(x.device, '/job:localhost/replica:0/task:0/device:CPU:1')
|
self.assertEqual(x.device, '/job:localhost/replica:0/task:0/device:CPU:1')
|
||||||
@ -1084,7 +1084,7 @@ class SendRecvTest(test_util.TensorFlowTestCase):
|
|||||||
def testLocalCrossDevice(self):
|
def testLocalCrossDevice(self):
|
||||||
gpu_device_name = '/job:localhost/replica:0/task:0/device:GPU:0'
|
gpu_device_name = '/job:localhost/replica:0/task:0/device:GPU:0'
|
||||||
with ops.device('GPU:0'):
|
with ops.device('GPU:0'):
|
||||||
t0 = constant_op.constant(1.0)
|
t0 = array_ops.identity(1.0)
|
||||||
self._send(t0, 't0', self.cpu_device)
|
self._send(t0, 't0', self.cpu_device)
|
||||||
with ops.device('cpu:0'):
|
with ops.device('cpu:0'):
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
@ -1115,4 +1115,5 @@ class EagerTensorCacheTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
context.set_log_device_placement(True)
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -18,24 +18,26 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import remote
|
from tensorflow.python.eager import remote
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import config
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
|
||||||
|
|
||||||
class SoftDevicePlacementTest(test.TestCase):
|
class SoftDevicePlacementTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(SoftDevicePlacementTest, self).setUp()
|
super(SoftDevicePlacementTest, self).setUp()
|
||||||
context._context = None
|
context._reset_context()
|
||||||
ops.enable_eager_execution_internal()
|
|
||||||
config.set_soft_device_placement(enabled=True)
|
config.set_soft_device_placement(enabled=True)
|
||||||
context.context().log_device_placement = True
|
context.context().log_device_placement = True
|
||||||
|
|
||||||
@ -90,13 +92,21 @@ class SoftDevicePlacementTest(test.TestCase):
|
|||||||
# We don't support nested device placement right now.
|
# We don't support nested device placement right now.
|
||||||
self.assertIn('GPU:0', c.device)
|
self.assertIn('GPU:0', c.device)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(('float', 1.0, None),
|
||||||
|
('int32', [1], dtypes.int32),
|
||||||
|
('string', ['a'], None))
|
||||||
|
def testSoftPlacedCPUConstant(self, value, dtype):
|
||||||
|
with ops.device('GPU:0'):
|
||||||
|
a = constant_op.constant(value, dtype=dtype)
|
||||||
|
self.assertIn('CPU:0', a.device)
|
||||||
|
self.assertIn('CPU:0', a.backing_device)
|
||||||
|
|
||||||
class HardDevicePlacementTest(test.TestCase):
|
|
||||||
|
class HardDevicePlacementTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(HardDevicePlacementTest, self).setUp()
|
super(HardDevicePlacementTest, self).setUp()
|
||||||
context._context = None
|
context._reset_context()
|
||||||
ops.enable_eager_execution_internal()
|
|
||||||
config.set_soft_device_placement(enabled=False)
|
config.set_soft_device_placement(enabled=False)
|
||||||
context.context().log_device_placement = True
|
context.context().log_device_placement = True
|
||||||
self.assertEqual(config.get_soft_device_placement(), False)
|
self.assertEqual(config.get_soft_device_placement(), False)
|
||||||
@ -114,13 +124,27 @@ class HardDevicePlacementTest(test.TestCase):
|
|||||||
self.assertIn('GPU:0', y.device)
|
self.assertIn('GPU:0', y.device)
|
||||||
self.assertIn('GPU:0', y.backing_device)
|
self.assertIn('GPU:0', y.backing_device)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(('float_cpu0', 'CPU:0', 1.0, None),
|
||||||
|
('int32_cpu0', 'CPU:0', [1], dtypes.int32),
|
||||||
|
('string_cpu0', 'CPU:0', ['a'], None),
|
||||||
|
('float_gpu0', 'GPU:0', 1.0, None),
|
||||||
|
('int32_gpu0', 'GPU:0', [1], dtypes.int32),
|
||||||
|
('string_gpu0', 'GPU:0', ['a'], None),
|
||||||
|
('float_gpu99', 'GPU:99', 1.0, None),
|
||||||
|
('int32_gpu99', 'GPU:99', [1], dtypes.int32),
|
||||||
|
('string_gpu99', 'GPU:99', ['a'], None))
|
||||||
|
def testHardPlacedCPUConstant(self, device, value, dtype):
|
||||||
|
with ops.device(device):
|
||||||
|
a = constant_op.constant(value, dtype=dtype)
|
||||||
|
self.assertIn('CPU:0', a.device)
|
||||||
|
self.assertIn('CPU:0', a.backing_device)
|
||||||
|
|
||||||
|
|
||||||
class ClusterPlacementTest(test.TestCase):
|
class ClusterPlacementTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(ClusterPlacementTest, self).setUp()
|
super(ClusterPlacementTest, self).setUp()
|
||||||
context._context = None
|
context._reset_context()
|
||||||
ops.enable_eager_execution_internal()
|
|
||||||
config.set_soft_device_placement(enabled=True)
|
config.set_soft_device_placement(enabled=True)
|
||||||
context.context().log_device_placement = True
|
context.context().log_device_placement = True
|
||||||
workers, _ = test_util.create_local_cluster(2, 0)
|
workers, _ = test_util.create_local_cluster(2, 0)
|
||||||
|
@ -1570,10 +1570,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
def testColocateWithRespected(self):
|
def testColocateWithRespected(self):
|
||||||
# TODO(b/113291792): Use multiple CPUs instead of a GPU.
|
# TODO(b/113291792): Use multiple CPUs instead of a GPU.
|
||||||
with ops.device('cpu:0'):
|
with ops.device('cpu:0'):
|
||||||
x = constant_op.constant(1.0)
|
x = array_ops.identity(1.0)
|
||||||
|
|
||||||
with ops.device('gpu:0'):
|
with ops.device('gpu:0'):
|
||||||
y = constant_op.constant(1.0)
|
y = array_ops.identity(1.0)
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def foo():
|
def foo():
|
||||||
@ -3239,9 +3239,9 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
|||||||
return b, a
|
return b, a
|
||||||
|
|
||||||
with ops.device('/device:CPU:0'):
|
with ops.device('/device:CPU:0'):
|
||||||
a = constant_op.constant(3.0)
|
a = array_ops.identity(3.0)
|
||||||
with ops.device('/device:GPU:0'):
|
with ops.device('/device:GPU:0'):
|
||||||
b = constant_op.constant(5.0)
|
b = array_ops.identity(5.0)
|
||||||
|
|
||||||
m1, m2 = func(a, b)
|
m1, m2 = func(a, b)
|
||||||
self.assertAllEqual(m1.numpy(), 5.0)
|
self.assertAllEqual(m1.numpy(), 5.0)
|
||||||
@ -3306,9 +3306,9 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
|||||||
devices = ['/device:CPU:0', '/device:GPU:0']
|
devices = ['/device:CPU:0', '/device:GPU:0']
|
||||||
for dev1, dev2 in itertools.product(devices, devices):
|
for dev1, dev2 in itertools.product(devices, devices):
|
||||||
with ops.device(dev1):
|
with ops.device(dev1):
|
||||||
a = constant_op.constant(1.0)
|
a = array_ops.identity(1.0)
|
||||||
with ops.device(dev2):
|
with ops.device(dev2):
|
||||||
b = constant_op.constant(10.0)
|
b = array_ops.identity(10.0)
|
||||||
|
|
||||||
ra, rb = func(a, b)
|
ra, rb = func(a, b)
|
||||||
self.assertEqual(ra.numpy(), 2.0)
|
self.assertEqual(ra.numpy(), 2.0)
|
||||||
@ -3469,13 +3469,13 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
|||||||
with ops.device('/device:CPU:0'):
|
with ops.device('/device:CPU:0'):
|
||||||
rc0 = resource_variable_ops.ResourceVariable(2.0)
|
rc0 = resource_variable_ops.ResourceVariable(2.0)
|
||||||
rc1 = resource_variable_ops.ResourceVariable(3.0)
|
rc1 = resource_variable_ops.ResourceVariable(3.0)
|
||||||
cc0 = constant_op.constant(5.0)
|
cc0 = array_ops.identity(5.0)
|
||||||
cc1 = constant_op.constant(7.0)
|
cc1 = array_ops.identity(7.0)
|
||||||
with ops.device('/device:GPU:0'):
|
with ops.device('/device:GPU:0'):
|
||||||
rg0 = resource_variable_ops.ResourceVariable(11.0)
|
rg0 = resource_variable_ops.ResourceVariable(11.0)
|
||||||
rg1 = resource_variable_ops.ResourceVariable(13.0)
|
rg1 = resource_variable_ops.ResourceVariable(13.0)
|
||||||
cg0 = constant_op.constant(17.0)
|
cg0 = array_ops.identity(17.0)
|
||||||
cg1 = constant_op.constant(19.0)
|
cg1 = array_ops.identity(19.0)
|
||||||
|
|
||||||
# Make sure tensors are on expected devices.
|
# Make sure tensors are on expected devices.
|
||||||
for tensor in [cc0, cc1]:
|
for tensor in [cc0, cc1]:
|
||||||
|
@ -278,39 +278,13 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
|
// We always generate CPU:0 tensors, but we may need to change the device
|
||||||
// memory. We approximate the same behavior for eager execution - keeping
|
// slightly, as for example from /job:localhost/... to /job:worker/...
|
||||||
// int32 tensors in host memory.
|
|
||||||
//
|
//
|
||||||
// We do so to preclude the need for callers into such kernels from having to
|
// Note that this is a shallow copy and will share the underlying buffer,
|
||||||
// explicitly place the int32 tensors in host memory. For example, without
|
// because we are copying to the same device.
|
||||||
// this, one needed:
|
|
||||||
//
|
|
||||||
// with tf.device('/gpu:0'):
|
|
||||||
// ...// code here
|
|
||||||
// with tf.device('/cpu:0'):
|
|
||||||
// shape = tf.constant(...)
|
|
||||||
// y = tf.random_uniform(shape)
|
|
||||||
//
|
|
||||||
// Without the CPU device block, tfe.ops.random_uniform would fail since the
|
|
||||||
// kernel expects the shape in host memory.
|
|
||||||
//
|
|
||||||
// With this support, we simplify the code:
|
|
||||||
//
|
|
||||||
// with tf.device('/gpu:0'):
|
|
||||||
// y = tf.random_uniform(...)
|
|
||||||
//
|
|
||||||
// The approximation is not exact there are GPU kernels which do not require
|
|
||||||
// host memory for int32 tensors. This will lead to a discrepancy between
|
|
||||||
// eager and graph execution.
|
|
||||||
//
|
|
||||||
// To support remote execution copy int32 tensors to another CPU device.
|
|
||||||
// TODO(ashankar): Fix this.
|
|
||||||
if (device_name != nullptr &&
|
if (device_name != nullptr &&
|
||||||
(TFE_TensorHandleDataType(handle.get()) != TF_INT32 ||
|
strstr(device_name, "/device:CPU:0") != nullptr) {
|
||||||
strstr(device_name, "/device:CPU:0") != nullptr)) {
|
|
||||||
// Note that this is a shallow copy and will share the underlying buffer
|
|
||||||
// if copying to the same device.
|
|
||||||
handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx,
|
handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx,
|
||||||
device_name, status.get()));
|
device_name, status.get()));
|
||||||
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) {
|
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) {
|
||||||
@ -318,6 +292,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We always enable implicit mirroring for constants. Without this, code
|
||||||
|
// written previously under the assumption that
|
||||||
|
//
|
||||||
|
// with tf.device('GPU:0'): x = tf.constant(1.0)
|
||||||
|
//
|
||||||
|
// will be placed in the GPU will suffer a non-trivial performance regression
|
||||||
|
// (measured at ~20% for certain benchmarks).
|
||||||
|
handle->handle->EnableImplicitMirroring();
|
||||||
|
|
||||||
return handle.release();
|
return handle.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -281,9 +281,9 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
def testStringTensorOnGPU(self):
|
def testStringTensorOnGPU(self):
|
||||||
with ops.device("/device:GPU:0"):
|
with ops.device("/device:GPU:0"):
|
||||||
with self.assertRaisesRegexp(
|
t = _create_tensor("test string")
|
||||||
RuntimeError, "Can't copy Tensor with type string to device"):
|
self.assertIn("CPU", t.device)
|
||||||
_create_tensor("test string")
|
self.assertIn("CPU", t.backing_device)
|
||||||
|
|
||||||
def testInvalidUTF8ProducesReasonableError(self):
|
def testInvalidUTF8ProducesReasonableError(self):
|
||||||
if sys.version_info[0] < 3:
|
if sys.version_info[0] < 3:
|
||||||
|
@ -33,6 +33,7 @@ from tensorflow.python.framework import errors
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_ops
|
from tensorflow.python.framework import test_ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
@ -380,10 +381,10 @@ class DeviceTest(test.TestCase):
|
|||||||
with ops.device('/device:CPU:1'):
|
with ops.device('/device:CPU:1'):
|
||||||
b = constant_op.constant(1.0)
|
b = constant_op.constant(1.0)
|
||||||
self.evaluate(b)
|
self.evaluate(b)
|
||||||
with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
|
with ops.device('/device:CPU:2'):
|
||||||
with ops.device('/device:CPU:2'):
|
c = constant_op.constant(1.0)
|
||||||
c = constant_op.constant(1.0)
|
self.evaluate(c)
|
||||||
self.evaluate(c)
|
self.assertIn('CPU:0', c.device)
|
||||||
|
|
||||||
# Ensure we can place ops on each of the device names
|
# Ensure we can place ops on each of the device names
|
||||||
for vcpu in vcpus:
|
for vcpu in vcpus:
|
||||||
@ -408,6 +409,7 @@ class DeviceTest(test.TestCase):
|
|||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
@reset_eager
|
@reset_eager
|
||||||
def testGpuNone(self):
|
def testGpuNone(self):
|
||||||
|
config.set_soft_device_placement(False)
|
||||||
gpus = config.list_physical_devices('GPU')
|
gpus = config.list_physical_devices('GPU')
|
||||||
self.assertGreater(len(gpus), 0)
|
self.assertGreater(len(gpus), 0)
|
||||||
|
|
||||||
@ -427,14 +429,16 @@ class DeviceTest(test.TestCase):
|
|||||||
self.assertEqual(len(config.get_visible_devices('GPU')), 0)
|
self.assertEqual(len(config.get_visible_devices('GPU')), 0)
|
||||||
self.assertEqual(len(config.list_logical_devices('XLA_GPU')), 0)
|
self.assertEqual(len(config.list_logical_devices('XLA_GPU')), 0)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
|
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||||
|
'Could not satisfy'):
|
||||||
with ops.device('/device:GPU:0'):
|
with ops.device('/device:GPU:0'):
|
||||||
a = constant_op.constant(1.0)
|
a = array_ops.identity(1.0)
|
||||||
self.evaluate(a)
|
self.evaluate(a)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
|
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||||
|
'Could not satisfy'):
|
||||||
with ops.device('/device:XLA_GPU:0'):
|
with ops.device('/device:XLA_GPU:0'):
|
||||||
a = constant_op.constant(1.0)
|
a = array_ops.identity(1.0)
|
||||||
self.evaluate(a)
|
self.evaluate(a)
|
||||||
|
|
||||||
# Modifying the visible devices is not supported
|
# Modifying the visible devices is not supported
|
||||||
@ -465,6 +469,7 @@ class DeviceTest(test.TestCase):
|
|||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
@reset_eager
|
@reset_eager
|
||||||
def testVirtualGpu(self):
|
def testVirtualGpu(self):
|
||||||
|
config.set_soft_device_placement(False)
|
||||||
gpus = config.list_physical_devices('GPU')
|
gpus = config.list_physical_devices('GPU')
|
||||||
self.assertNotEqual(len(gpus), 0)
|
self.assertNotEqual(len(gpus), 0)
|
||||||
|
|
||||||
@ -479,12 +484,13 @@ class DeviceTest(test.TestCase):
|
|||||||
self.assertTrue(len(logical_gpus), len(gpus) + 1)
|
self.assertTrue(len(logical_gpus), len(gpus) + 1)
|
||||||
for i in range(0, len(logical_gpus)):
|
for i in range(0, len(logical_gpus)):
|
||||||
with ops.device('/device:GPU:' + str(i)):
|
with ops.device('/device:GPU:' + str(i)):
|
||||||
a = constant_op.constant(1.0)
|
a = array_ops.identity(1.0)
|
||||||
self.evaluate(a)
|
self.evaluate(a)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
|
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||||
|
'Could not satisfy'):
|
||||||
with ops.device('/device:GPU:' + str(len(logical_gpus))):
|
with ops.device('/device:GPU:' + str(len(logical_gpus))):
|
||||||
a = constant_op.constant(1.0)
|
a = array_ops.identity(1.0)
|
||||||
self.evaluate(a)
|
self.evaluate(a)
|
||||||
|
|
||||||
# Modifying the GPU configuration is not supported
|
# Modifying the GPU configuration is not supported
|
||||||
|
@ -224,6 +224,10 @@ def constant(value, dtype=None, shape=None, name="Const"):
|
|||||||
...
|
...
|
||||||
NotImplementedError: ...
|
NotImplementedError: ...
|
||||||
|
|
||||||
|
`tf.constant` will _always_ create CPU (host) tensors. In order to create
|
||||||
|
tensors on other devices, use `tf.identity`. (If the `value` is an eager
|
||||||
|
Tensor, however, the tensor will be returned unmodified as mentioned above.)
|
||||||
|
|
||||||
Related Ops:
|
Related Ops:
|
||||||
|
|
||||||
* `tf.convert_to_tensor` is similar but:
|
* `tf.convert_to_tensor` is similar but:
|
||||||
|
Loading…
Reference in New Issue
Block a user