Copy resource device correctly in TFE_TensorHandleCopyToDevice

PiperOrigin-RevId: 281992454
Change-Id: I223f05688b2c873dcffa02a743fa3302bda80f16
This commit is contained in:
Akshay Modi 2019-11-22 10:15:24 -08:00 committed by TensorFlower Gardener
parent 28fa9a621d
commit 9589742428
2 changed files with 18 additions and 2 deletions

View File

@ -1091,9 +1091,8 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
EagerExecutor* executor, Device* dstd,
TensorHandle** result) {
TF_RETURN_IF_ERROR(executor->status());
Device* resource_device = (h->dtype == DT_RESOURCE) ? dstd : nullptr;
TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(
ctx->CanonicalDevice(dstd), dstd, resource_device, h->dtype, ctx,
ctx->CanonicalDevice(dstd), dstd, h->resource_device(), h->dtype, ctx,
result));
// Note that `h` may not be currently ready. However execution order will

View File

@ -37,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@ -411,6 +412,21 @@ class TFETensorTest(test_util.TensorFlowTestCase):
self.assertAllEqual(
np.array(memoryview(t)), np.array([0.0], dtype=np.float32))
def testResourceTensorCopy(self):
if not test_util.is_gpu_available():
self.skipTest("GPU only")
with ops.device("GPU:0"):
v = resource_variable_ops.ResourceVariable(1.)
read_handle_on_gpu = resource_variable_ops.read_variable_op(
v.handle, dtypes.float32)
handle_on_cpu = v.handle.cpu()
read_handle_on_cpu = resource_variable_ops.read_variable_op(
handle_on_cpu, dtypes.float32)
self.assertAllEqual(read_handle_on_cpu, read_handle_on_gpu)
class TFETensorUtilTest(test_util.TensorFlowTestCase):
@ -523,5 +539,6 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
ValueError, "non-rectangular Python sequence"):
constant_op.constant(l)
if __name__ == "__main__":
test.main()