From 958974242816d551fa695ddeaf4e37fe78537ef6 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Fri, 22 Nov 2019 10:15:24 -0800 Subject: [PATCH] Copy resource device correctly in TFE_TensorHandleCopyToDevice PiperOrigin-RevId: 281992454 Change-Id: I223f05688b2c873dcffa02a743fa3302bda80f16 --- tensorflow/core/common_runtime/eager/execute.cc | 3 +-- tensorflow/python/eager/tensor_test.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 56ec9e66f8a..aa5af46caae 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -1091,9 +1091,8 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, EagerExecutor* executor, Device* dstd, TensorHandle** result) { TF_RETURN_IF_ERROR(executor->status()); - Device* resource_device = (h->dtype == DT_RESOURCE) ? dstd : nullptr; TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle( - ctx->CanonicalDevice(dstd), dstd, resource_device, h->dtype, ctx, + ctx->CanonicalDevice(dstd), dstd, h->resource_device(), h->dtype, ctx, result)); // Note that `h` may not be currently ready. However execution order will diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 2d4ce0030c8..5f4b75b8bbd 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import io_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -411,6 +412,21 @@ class TFETensorTest(test_util.TensorFlowTestCase): self.assertAllEqual( np.array(memoryview(t)), np.array([0.0], dtype=np.float32)) + def testResourceTensorCopy(self): + if not test_util.is_gpu_available(): + self.skipTest("GPU only") + + with ops.device("GPU:0"): + v = resource_variable_ops.ResourceVariable(1.) + + read_handle_on_gpu = resource_variable_ops.read_variable_op( + v.handle, dtypes.float32) + handle_on_cpu = v.handle.cpu() + read_handle_on_cpu = resource_variable_ops.read_variable_op( + handle_on_cpu, dtypes.float32) + + self.assertAllEqual(read_handle_on_cpu, read_handle_on_gpu) + class TFETensorUtilTest(test_util.TensorFlowTestCase): @@ -523,5 +539,6 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): ValueError, "non-rectangular Python sequence"): constant_op.constant(l) + if __name__ == "__main__": test.main()