Copy resource device correctly in TFE_TensorHandleCopyToDevice
PiperOrigin-RevId: 281992454 Change-Id: I223f05688b2c873dcffa02a743fa3302bda80f16
This commit is contained in:
parent
28fa9a621d
commit
9589742428
@ -1091,9 +1091,8 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
EagerExecutor* executor, Device* dstd,
|
||||
TensorHandle** result) {
|
||||
TF_RETURN_IF_ERROR(executor->status());
|
||||
Device* resource_device = (h->dtype == DT_RESOURCE) ? dstd : nullptr;
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(
|
||||
ctx->CanonicalDevice(dstd), dstd, resource_device, h->dtype, ctx,
|
||||
ctx->CanonicalDevice(dstd), dstd, h->resource_device(), h->dtype, ctx,
|
||||
result));
|
||||
|
||||
// Note that `h` may not be currently ready. However execution order will
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
@ -411,6 +412,21 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(
|
||||
np.array(memoryview(t)), np.array([0.0], dtype=np.float32))
|
||||
|
||||
def testResourceTensorCopy(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("GPU only")
|
||||
|
||||
with ops.device("GPU:0"):
|
||||
v = resource_variable_ops.ResourceVariable(1.)
|
||||
|
||||
read_handle_on_gpu = resource_variable_ops.read_variable_op(
|
||||
v.handle, dtypes.float32)
|
||||
handle_on_cpu = v.handle.cpu()
|
||||
read_handle_on_cpu = resource_variable_ops.read_variable_op(
|
||||
handle_on_cpu, dtypes.float32)
|
||||
|
||||
self.assertAllEqual(read_handle_on_cpu, read_handle_on_gpu)
|
||||
|
||||
|
||||
class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@ -523,5 +539,6 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
ValueError, "non-rectangular Python sequence"):
|
||||
constant_op.constant(l)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user