From 97cdd4d16a81a349696f10451b7d564bfa99664f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Feb 2020 12:19:21 -0800 Subject: [PATCH] Fix device placement logic in ConvertToEagerTensor. The fix consists in always creating a host tensor: the inputs to ConvertToEagerTensor are host Python objects, so it makes sense that the created tensor should be a host tensor too. The user can control GPU copies by using tf.identity. PiperOrigin-RevId: 297174740 Change-Id: I01f2aa9be3eb29fd49c7d81823e044db292b2d7c --- tensorflow/compiler/tests/eager_test.py | 2 +- tensorflow/core/common_runtime/device_mgr.cc | 2 +- .../core/common_runtime/dynamic_device_mgr.cc | 3 +- .../core/common_runtime/eager/context.cc | 18 +++++++- .../distribute/mirrored_strategy_test.py | 4 +- .../python/distribute/strategy_test_lib.py | 4 +- tensorflow/python/eager/core_test.py | 5 ++- .../python/eager/device_placement_test.py | 40 +++++++++++++---- tensorflow/python/eager/function_test.py | 20 ++++----- tensorflow/python/eager/pywrap_tensor.cc | 45 ++++++------------- tensorflow/python/eager/tensor_test.py | 6 +-- tensorflow/python/framework/config_test.py | 28 +++++++----- tensorflow/python/framework/constant_op.py | 4 ++ 13 files changed, 107 insertions(+), 74 deletions(-) diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index a03980f20ba..0ed81b7e9e5 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -695,7 +695,7 @@ class EagerFunctionTest(xla_test.XLATestCase): wholly_compiled_f = def_function.function(f) op_by_op_f = def_function.function(f, experimental_compile=False) - x = constant_op.constant([0.0, 2.0], name='data') + x = array_ops.identity([0.0, 2.0], name='data') # When function is wholly compiled, all outputs will be on the # device on which it is run. diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc index b17278fb365..c7583c374f2 100644 --- a/tensorflow/core/common_runtime/device_mgr.cc +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -45,7 +45,7 @@ StaticDeviceMgr::StaticDeviceMgr(std::vector> devices) } const auto& t = d->device_type(); device_type_counts_[t]++; - if (cpu_device_ == nullptr && t == "CPU") { + if (cpu_device_ == nullptr && t == "CPU" && d->parsed_name().id == 0) { cpu_device_ = d.get(); } } diff --git a/tensorflow/core/common_runtime/dynamic_device_mgr.cc b/tensorflow/core/common_runtime/dynamic_device_mgr.cc index f7e2e27e4ab..4bea08bb021 100644 --- a/tensorflow/core/common_runtime/dynamic_device_mgr.cc +++ b/tensorflow/core/common_runtime/dynamic_device_mgr.cc @@ -194,7 +194,8 @@ Device* DynamicDeviceMgr::HostCPU() const { } cpu_device_ = nullptr; for (const auto& pair : dynamic_devices_) { - if (pair.first->device_type() == DEVICE_CPU) { + if (pair.first->device_type() == DEVICE_CPU && + pair.first->parsed_name().id == 0) { cpu_device_ = pair.first; break; } diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index fb051a9a583..5b2035edf43 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -167,6 +167,16 @@ std::vector DevicesToString(const PrioritizedDeviceVector& devices) { return v; } +std::vector DeviceTypesToString( + const PrioritizedDeviceTypeVector& types) { + std::vector v; + v.reserve(types.size()); + for (const auto& p : types) { + v.push_back(p.first.type_string()); + } + return v; +} + // Selects the "best" device that both exists and is supported. // // The `existing` argument specifies the available devices in the system, in @@ -232,13 +242,17 @@ Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred, return errors::InvalidArgument( "Could not satisfy device specification '", preferred, "'. enable_soft_placement=", AllowSoftPlacement(), - ". All available devices [", + ". Supported device types [", + absl::StrJoin(DeviceTypesToString(supported), ", "), + "]. All available devices [", absl::StrJoin(DevicesToString(existing), ", "), "]."); } return errors::InvalidArgument( "No supported device found in available devices [", absl::StrJoin(DevicesToString(existing), ", "), - "]. enable_soft_placement=", AllowSoftPlacement(), "."); + "]. enable_soft_placement=", AllowSoftPlacement(), + ". Supported devices types [", + absl::StrJoin(DeviceTypesToString(supported), ", "), "]."); } void EagerContext::ResetClusterFLR( diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py index 9f4b07a3e75..0ab4018ce13 100644 --- a/tensorflow/python/distribute/mirrored_strategy_test.py +++ b/tensorflow/python/distribute/mirrored_strategy_test.py @@ -1356,14 +1356,14 @@ class FunctionTest(test.TestCase, parameterized.TestCase): def forward(x, w, b): return x * w + b - x = constant_op.constant([1.0], name="x_useless") + x = array_ops.identity([1.0], name="x_useless") concrete_forward = forward.get_concrete_function(x, w._primary, b._primary) with distribution.scope(): def replica_fn(): with backprop.GradientTape() as t: - x = constant_op.constant([1.0], name="x") + x = array_ops.identity([1.0], name="x") loss = concrete_forward(x, w._get(), b._get()) - [1.0] return t.gradient(loss, [w, b]) diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py index c889484ae68..00730959d4e 100644 --- a/tensorflow/python/distribute/strategy_test_lib.py +++ b/tensorflow/python/distribute/strategy_test_lib.py @@ -688,9 +688,9 @@ class RemoteSingleWorkerMirroredStrategyBase(DistributionTestBase): def _testDeviceScope(self, distribution): with distribution.scope(): - a = constant_op.constant(1.) + a = array_ops.identity(1.) with ops.device("/cpu:0"): - b = constant_op.constant(1.) + b = array_ops.identity(1.) if context.executing_eagerly(): device = "/job:worker/replica:0/task:0/device:CPU:0" else: diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 273777090c6..d0b21cb237c 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -395,7 +395,7 @@ class TFETest(test_util.TensorFlowTestCase): def testMultiCpuPlacement(self): with ops.device('cpu:1'): - x = constant_op.constant(1.0) + x = array_ops.identity(1.0) with ops.device('cpu:0'): y = array_ops.identity(x) self.assertEqual(x.device, '/job:localhost/replica:0/task:0/device:CPU:1') @@ -1084,7 +1084,7 @@ class SendRecvTest(test_util.TensorFlowTestCase): def testLocalCrossDevice(self): gpu_device_name = '/job:localhost/replica:0/task:0/device:GPU:0' with ops.device('GPU:0'): - t0 = constant_op.constant(1.0) + t0 = array_ops.identity(1.0) self._send(t0, 't0', self.cpu_device) with ops.device('cpu:0'): self.assertAllEqual( @@ -1115,4 +1115,5 @@ class EagerTensorCacheTest(test_util.TensorFlowTestCase): if __name__ == '__main__': + context.set_log_device_placement(True) test.main() diff --git a/tensorflow/python/eager/device_placement_test.py b/tensorflow/python/eager/device_placement_test.py index 32ca6d3a826..af6c68243b4 100644 --- a/tensorflow/python/eager/device_placement_test.py +++ b/tensorflow/python/eager/device_placement_test.py @@ -18,24 +18,26 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import remote from tensorflow.python.eager import test from tensorflow.python.framework import config from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -class SoftDevicePlacementTest(test.TestCase): +class SoftDevicePlacementTest(test.TestCase, parameterized.TestCase): def setUp(self): super(SoftDevicePlacementTest, self).setUp() - context._context = None - ops.enable_eager_execution_internal() + context._reset_context() config.set_soft_device_placement(enabled=True) context.context().log_device_placement = True @@ -90,13 +92,21 @@ class SoftDevicePlacementTest(test.TestCase): # We don't support nested device placement right now. self.assertIn('GPU:0', c.device) + @parameterized.named_parameters(('float', 1.0, None), + ('int32', [1], dtypes.int32), + ('string', ['a'], None)) + def testSoftPlacedCPUConstant(self, value, dtype): + with ops.device('GPU:0'): + a = constant_op.constant(value, dtype=dtype) + self.assertIn('CPU:0', a.device) + self.assertIn('CPU:0', a.backing_device) -class HardDevicePlacementTest(test.TestCase): + +class HardDevicePlacementTest(test.TestCase, parameterized.TestCase): def setUp(self): super(HardDevicePlacementTest, self).setUp() - context._context = None - ops.enable_eager_execution_internal() + context._reset_context() config.set_soft_device_placement(enabled=False) context.context().log_device_placement = True self.assertEqual(config.get_soft_device_placement(), False) @@ -114,13 +124,27 @@ class HardDevicePlacementTest(test.TestCase): self.assertIn('GPU:0', y.device) self.assertIn('GPU:0', y.backing_device) + @parameterized.named_parameters(('float_cpu0', 'CPU:0', 1.0, None), + ('int32_cpu0', 'CPU:0', [1], dtypes.int32), + ('string_cpu0', 'CPU:0', ['a'], None), + ('float_gpu0', 'GPU:0', 1.0, None), + ('int32_gpu0', 'GPU:0', [1], dtypes.int32), + ('string_gpu0', 'GPU:0', ['a'], None), + ('float_gpu99', 'GPU:99', 1.0, None), + ('int32_gpu99', 'GPU:99', [1], dtypes.int32), + ('string_gpu99', 'GPU:99', ['a'], None)) + def testHardPlacedCPUConstant(self, device, value, dtype): + with ops.device(device): + a = constant_op.constant(value, dtype=dtype) + self.assertIn('CPU:0', a.device) + self.assertIn('CPU:0', a.backing_device) + class ClusterPlacementTest(test.TestCase): def setUp(self): super(ClusterPlacementTest, self).setUp() - context._context = None - ops.enable_eager_execution_internal() + context._reset_context() config.set_soft_device_placement(enabled=True) context.context().log_device_placement = True workers, _ = test_util.create_local_cluster(2, 0) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 0a34d4a3852..7b599a995e2 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1570,10 +1570,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase): def testColocateWithRespected(self): # TODO(b/113291792): Use multiple CPUs instead of a GPU. with ops.device('cpu:0'): - x = constant_op.constant(1.0) + x = array_ops.identity(1.0) with ops.device('gpu:0'): - y = constant_op.constant(1.0) + y = array_ops.identity(1.0) @def_function.function def foo(): @@ -3239,9 +3239,9 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase): return b, a with ops.device('/device:CPU:0'): - a = constant_op.constant(3.0) + a = array_ops.identity(3.0) with ops.device('/device:GPU:0'): - b = constant_op.constant(5.0) + b = array_ops.identity(5.0) m1, m2 = func(a, b) self.assertAllEqual(m1.numpy(), 5.0) @@ -3306,9 +3306,9 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase): devices = ['/device:CPU:0', '/device:GPU:0'] for dev1, dev2 in itertools.product(devices, devices): with ops.device(dev1): - a = constant_op.constant(1.0) + a = array_ops.identity(1.0) with ops.device(dev2): - b = constant_op.constant(10.0) + b = array_ops.identity(10.0) ra, rb = func(a, b) self.assertEqual(ra.numpy(), 2.0) @@ -3469,13 +3469,13 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase): with ops.device('/device:CPU:0'): rc0 = resource_variable_ops.ResourceVariable(2.0) rc1 = resource_variable_ops.ResourceVariable(3.0) - cc0 = constant_op.constant(5.0) - cc1 = constant_op.constant(7.0) + cc0 = array_ops.identity(5.0) + cc1 = array_ops.identity(7.0) with ops.device('/device:GPU:0'): rg0 = resource_variable_ops.ResourceVariable(11.0) rg1 = resource_variable_ops.ResourceVariable(13.0) - cg0 = constant_op.constant(17.0) - cg1 = constant_op.constant(19.0) + cg0 = array_ops.identity(17.0) + cg1 = array_ops.identity(19.0) # Make sure tensors are on expected devices. for tensor in [cc0, cc1]: diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index b5c9bfb6824..f8e1fb568ac 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -278,39 +278,13 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx, } } - // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host - // memory. We approximate the same behavior for eager execution - keeping - // int32 tensors in host memory. + // We always generate CPU:0 tensors, but we may need to change the device + // slightly, as for example from /job:localhost/... to /job:worker/... // - // We do so to preclude the need for callers into such kernels from having to - // explicitly place the int32 tensors in host memory. For example, without - // this, one needed: - // - // with tf.device('/gpu:0'): - // ...// code here - // with tf.device('/cpu:0'): - // shape = tf.constant(...) - // y = tf.random_uniform(shape) - // - // Without the CPU device block, tfe.ops.random_uniform would fail since the - // kernel expects the shape in host memory. - // - // With this support, we simplify the code: - // - // with tf.device('/gpu:0'): - // y = tf.random_uniform(...) - // - // The approximation is not exact there are GPU kernels which do not require - // host memory for int32 tensors. This will lead to a discrepancy between - // eager and graph execution. - // - // To support remote execution copy int32 tensors to another CPU device. - // TODO(ashankar): Fix this. + // Note that this is a shallow copy and will share the underlying buffer, + // because we are copying to the same device. if (device_name != nullptr && - (TFE_TensorHandleDataType(handle.get()) != TF_INT32 || - strstr(device_name, "/device:CPU:0") != nullptr)) { - // Note that this is a shallow copy and will share the underlying buffer - // if copying to the same device. + strstr(device_name, "/device:CPU:0") != nullptr) { handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx, device_name, status.get())); if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) { @@ -318,6 +292,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx, } } + // We always enable implicit mirroring for constants. Without this, code + // written previously under the assumption that + // + // with tf.device('GPU:0'): x = tf.constant(1.0) + // + // will be placed in the GPU will suffer a non-trivial performance regression + // (measured at ~20% for certain benchmarks). + handle->handle->EnableImplicitMirroring(); + return handle.release(); } diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index dd1f049cdcc..fe4a7933a32 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -281,9 +281,9 @@ class TFETensorTest(test_util.TensorFlowTestCase): @test_util.run_gpu_only def testStringTensorOnGPU(self): with ops.device("/device:GPU:0"): - with self.assertRaisesRegexp( - RuntimeError, "Can't copy Tensor with type string to device"): - _create_tensor("test string") + t = _create_tensor("test string") + self.assertIn("CPU", t.device) + self.assertIn("CPU", t.backing_device) def testInvalidUTF8ProducesReasonableError(self): if sys.version_info[0] < 3: diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py index 72612a21cbf..2ef7d737d73 100644 --- a/tensorflow/python/framework/config_test.py +++ b/tensorflow/python/framework/config_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -380,10 +381,10 @@ class DeviceTest(test.TestCase): with ops.device('/device:CPU:1'): b = constant_op.constant(1.0) self.evaluate(b) - with self.assertRaisesRegexp(RuntimeError, 'unknown device'): - with ops.device('/device:CPU:2'): - c = constant_op.constant(1.0) - self.evaluate(c) + with ops.device('/device:CPU:2'): + c = constant_op.constant(1.0) + self.evaluate(c) + self.assertIn('CPU:0', c.device) # Ensure we can place ops on each of the device names for vcpu in vcpus: @@ -408,6 +409,7 @@ class DeviceTest(test.TestCase): @test_util.run_gpu_only @reset_eager def testGpuNone(self): + config.set_soft_device_placement(False) gpus = config.list_physical_devices('GPU') self.assertGreater(len(gpus), 0) @@ -427,14 +429,16 @@ class DeviceTest(test.TestCase): self.assertEqual(len(config.get_visible_devices('GPU')), 0) self.assertEqual(len(config.list_logical_devices('XLA_GPU')), 0) - with self.assertRaisesRegexp(RuntimeError, 'unknown device'): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + 'Could not satisfy'): with ops.device('/device:GPU:0'): - a = constant_op.constant(1.0) + a = array_ops.identity(1.0) self.evaluate(a) - with self.assertRaisesRegexp(RuntimeError, 'unknown device'): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + 'Could not satisfy'): with ops.device('/device:XLA_GPU:0'): - a = constant_op.constant(1.0) + a = array_ops.identity(1.0) self.evaluate(a) # Modifying the visible devices is not supported @@ -465,6 +469,7 @@ class DeviceTest(test.TestCase): @test_util.run_gpu_only @reset_eager def testVirtualGpu(self): + config.set_soft_device_placement(False) gpus = config.list_physical_devices('GPU') self.assertNotEqual(len(gpus), 0) @@ -479,12 +484,13 @@ class DeviceTest(test.TestCase): self.assertTrue(len(logical_gpus), len(gpus) + 1) for i in range(0, len(logical_gpus)): with ops.device('/device:GPU:' + str(i)): - a = constant_op.constant(1.0) + a = array_ops.identity(1.0) self.evaluate(a) - with self.assertRaisesRegexp(RuntimeError, 'unknown device'): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + 'Could not satisfy'): with ops.device('/device:GPU:' + str(len(logical_gpus))): - a = constant_op.constant(1.0) + a = array_ops.identity(1.0) self.evaluate(a) # Modifying the GPU configuration is not supported diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index 4d9aa29ad60..9736bb8b78b 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -224,6 +224,10 @@ def constant(value, dtype=None, shape=None, name="Const"): ... NotImplementedError: ... + `tf.constant` will _always_ create CPU (host) tensors. In order to create + tensors on other devices, use `tf.identity`. (If the `value` is an eager + Tensor, however, the tensor will be returned unmodified as mentioned above.) + Related Ops: * `tf.convert_to_tensor` is similar but: