Allow TFE_TensorHandleCopyToDevice to accept the same device as source and
destination. In that case, the returned handle reuses the same underlying
buffer rather than copying.

PiperOrigin-RevId: 164909906
commit 2173b5b0a5 (parent 13eb3b90e9)
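For context, a minimal usage sketch of the new behavior through the eager C API. This is not from the commit: it assumes `ctx` is a live `TFE_Context*` and `hgpu` is a `TFE_TensorHandle*` already placed on gpu:0, and the device name string is illustrative.

#include "tensorflow/c/eager/c_api.h"

void CopyToSameDevice(TFE_Context* ctx, TFE_TensorHandle* hgpu) {
  TF_Status* status = TF_NewStatus();
  // After this commit, copying a handle to the device it already lives on
  // succeeds and returns a new handle backed by the same buffer, instead of
  // failing with TF_INVALID_ARGUMENT as before.
  TFE_TensorHandle* hsame = TFE_TensorHandleCopyToDevice(
      hgpu, ctx, "/job:localhost/replica:0/task:0/gpu:0", status);
  if (TF_GetCode(status) == TF_OK) {
    // `hsame` and `hgpu` share the underlying tensor buffer; each handle
    // keeps the buffer alive independently.
    TFE_DeleteTensorHandle(hsame);
  }
  TF_DeleteStatus(status);
}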
@@ -201,39 +201,25 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
   if (!status->status.ok()) return nullptr;

   tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d;
-  const bool src_cpu = IsCPU(srcd);
+  bool is_same_device =
+      (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
   const bool dst_cpu = IsCPU(dstd);
-  if (!src_cpu && !dst_cpu) {
+  if (is_same_device) {
+    return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
+  }
+  const bool src_cpu = IsCPU(srcd);
+  if (src_cpu == dst_cpu) {
     TF_SetStatus(
         status, TF_INVALID_ARGUMENT,
         tensorflow::strings::StrCat(
             "TFE_TensorHandleCopyToDevice requires either the source "
-            "TFE_TensorHandle be on or the destination device be CPU (they "
-            "are ",
+            "TFE_TensorHandle be on or the destination device be on CPU "
+            "or be the same (they are ",
             DeviceName(srcd), " and ", DeviceName(dstd), " in this call)")
             .c_str());
     return nullptr;
   }
   tensorflow::Tensor* src = &(h->t);
-  if (src_cpu && dst_cpu) {
-    // There must be a better way, but for now redirect through proto to ensure
-    // that the underlying buffers are not shared.
-    tensorflow::TensorProto proto;
-    src->AsProtoTensorContent(&proto);
-    tensorflow::Tensor dst(src->dtype(), src->shape());
-    if (!dst.FromProto(proto)) {
-      TF_SetStatus(
-          status, TF_INTERNAL,
-          tensorflow::strings::StrCat(
-              "error copying between TFE_TensorHandles on CPU. Consider filing "
-              "a bug report at https://github.com/tensorflow/tensorflow/issues "
-              "mentioning version: ",
-              TF_Version(), " and ", __FILE__, ":", __LINE__)
-              .c_str());
-      return nullptr;
-    }
-    return new TFE_TensorHandle(dst, nullptr);
-  }
   if (src_cpu) {
     tensorflow::Tensor dst(
         dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
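For contrast, a copy between two distinct non-CPU devices still takes the error branch above. A hedged sketch of that case, assuming `ctx` and a handle `hgpu0` on gpu:0 exist and a second GPU is present; the device names are illustrative:

#include <stdio.h>
#include "tensorflow/c/eager/c_api.h"

void CopyAcrossGpus(TFE_Context* ctx, TFE_TensorHandle* hgpu0) {
  TF_Status* status = TF_NewStatus();
  // gpu:0 -> gpu:1 matches neither the same-device fast path nor the
  // "exactly one side on CPU" case, so it still fails with
  // TF_INVALID_ARGUMENT; callers must stage the copy through host memory.
  TFE_TensorHandle* h = TFE_TensorHandleCopyToDevice(
      hgpu0, ctx, "/job:localhost/replica:0/task:0/gpu:1", status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "copy failed: %s\n", TF_Message(status));  // expected
  } else {
    TFE_DeleteTensorHandle(h);
  }
  TF_DeleteStatus(status);
}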
@@ -54,10 +54,10 @@ extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,

 // Create a new TFE_TensorHandle with the same contents as 'h' but placed
 // in the memory of the device name 'device_name'.
-//
-// Currently requires at least one of the source or destination devices to
-// be CPU (i.e., for the source or destination tensor to be placed in
-// host memory).
+// If source and destination are the same device, then this creates a new handle
+// that shares the underlying buffer. Otherwise, it currently requires at least
+// one of the source or destination devices to be CPU (i.e., for the source or
+// destination tensor to be placed in host memory).
 extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                                       TFE_Context* ctx,
                                                       const char* device_name,
@@ -170,14 +170,22 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
       ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
       continue;
     }
-    // Copy back to CPU
-    TFE_TensorHandle* hcopy =
-        TFE_TensorHandleCopyToDevice(hdevice, ctx, kCPUDevice, status.get());
+    // Copy from device to the same device.
+    TFE_TensorHandle* hdevice2 =
+        TFE_TensorHandleCopyToDevice(hdevice, ctx, name.c_str(), status.get());
     if (TF_GetCode(status.get()) != TF_OK) {
       ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
       continue;
     }
     TFE_DeleteTensorHandle(hdevice);
+    // Copy back to CPU
+    TFE_TensorHandle* hcopy =
+        TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get());
+    if (TF_GetCode(status.get()) != TF_OK) {
+      ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
+      continue;
+    }
+    TFE_DeleteTensorHandle(hdevice2);

     // Ensure that the contents are the same!
     TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
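One detail this test relies on: `hdevice` is deleted before `hcopy` is created from `hdevice2`. That is safe because `tensorflow::Tensor` reference-counts its underlying buffer, so the handle produced by the same-device copy keeps the shared data alive on its own.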
@@ -150,21 +150,15 @@ class TFETest(test_util.TensorFlowTestCase):
     if not context.context().num_gpus():
      self.skipTest('No GPUs found')

-    cpu = tensor.Tensor([[1., 2.], [3., 4.]])
-    c2g = cpu.as_gpu_tensor()
-    # Exercise a copy from GPU to CPU, even though we ignore the value.
-    _ = c2g.as_cpu_tensor()
-
-    with self.assertRaises(errors.InvalidArgumentError):
-      # c2g is on GPU. Copying between GPU devices fails
-      # (must redirect through CPU for now).
-      # TODO(ashankar): Perhaps the function should not fail and instead
-      # faciliate the copy through host memory?
-      c2g.as_gpu_tensor()
+    x = tensor.Tensor([[1., 2.], [3., 4.]])
+    x = x.as_cpu_tensor()
+    x = x.as_gpu_tensor()
+    x = x.as_gpu_tensor()
+    x = x.as_cpu_tensor()

     # Invalid device
     with self.assertRaises(errors.InvalidArgumentError):
-      cpu.as_gpu_tensor(context.context().num_gpus() + 1)
+      x.as_gpu_tensor(context.context().num_gpus() + 1)

   def testNumpyForceCPU(self):
     if not context.context().num_gpus():
@@ -274,7 +268,8 @@ class TFETest(test_util.TensorFlowTestCase):
     product = execute.execute(
         'MatMul',
         num_outputs=1,
-        inputs=[tensor.Tensor([[3]]), tensor.Tensor([[5]])],
+        inputs=[tensor.Tensor([[3]]),
+                tensor.Tensor([[5]])],
         attrs=('transpose_a', True, 'transpose_b', False, 'T',
                dtypes.int32.as_datatype_enum))[0]
     self.assertEqual([[15]], product.numpy())
@@ -475,8 +470,8 @@ class TFETest(test_util.TensorFlowTestCase):
       with context.device('gpu:0'):
         y = truncated_normal(shape)
       # Add would fail if x and y were not on the same device.
-      execute.execute('Add', 1, inputs=[x, y],
-                      attrs=('T', x.dtype.as_datatype_enum))
+      execute.execute(
+          'Add', 1, inputs=[x, y], attrs=('T', x.dtype.as_datatype_enum))

   def testInvalidDevice(self):
     with self.assertRaises(ValueError):