Allow TFE_TensorHandleCopyToDevice to accept the same device as both source and destination. In that case the returned handle reuses the same underlying buffer.

PiperOrigin-RevId: 164909906
Commit: 2173b5b0a5
Parent: 13eb3b90e9
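For illustration only, here is a minimal sketch of the behavior this commit enables, using just the eager C API calls that appear in the diff below. The helper name CopyToSameDevice and the arguments ctx, hdevice, and device_name (the device hdevice already lives on) are assumptions for the sketch, not part of the commit:

#include "tensorflow/c/eager/c_api.h"

// Sketch: copy a tensor handle to the device it is already placed on.
// After this commit the call succeeds and the returned handle shares the
// underlying buffer instead of making a copy.
void CopyToSameDevice(TFE_Context* ctx, TFE_TensorHandle* hdevice,
                      const char* device_name) {
  TF_Status* status = TF_NewStatus();
  TFE_TensorHandle* same =
      TFE_TensorHandleCopyToDevice(hdevice, ctx, device_name, status);
  if (TF_GetCode(status) == TF_OK) {
    // `same` aliases the tensor held by `hdevice`; both handles must still be
    // deleted independently.
    TFE_DeleteTensorHandle(same);
  }
  TF_DeleteStatus(status);
}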
@@ -201,39 +201,25 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
     if (!status->status.ok()) return nullptr;
   }
   tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d;
-  const bool src_cpu = IsCPU(srcd);
+  bool is_same_device =
+      (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
   const bool dst_cpu = IsCPU(dstd);
-  if (!src_cpu && !dst_cpu) {
+  if (is_same_device) {
+    return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
+  }
+  const bool src_cpu = IsCPU(srcd);
+  if (src_cpu == dst_cpu) {
     TF_SetStatus(
         status, TF_INVALID_ARGUMENT,
         tensorflow::strings::StrCat(
             "TFE_TensorHandleCopyToDevice requires either the source "
-            "TFE_TensorHandle be on or the destination device be CPU (they "
-            "are ",
+            "TFE_TensorHandle be on or the destination device be on CPU "
+            "or be the same (they are ",
             DeviceName(srcd), " and ", DeviceName(dstd), " in this call)")
             .c_str());
     return nullptr;
   }
   tensorflow::Tensor* src = &(h->t);
-  if (src_cpu && dst_cpu) {
-    // There must be a better way, but for now redirect through proto to ensure
-    // that the underlying buffers are not shared.
-    tensorflow::TensorProto proto;
-    src->AsProtoTensorContent(&proto);
-    tensorflow::Tensor dst(src->dtype(), src->shape());
-    if (!dst.FromProto(proto)) {
-      TF_SetStatus(
-          status, TF_INTERNAL,
-          tensorflow::strings::StrCat(
-              "error copying between TFE_TensorHandles on CPU. Consider filing "
-              "a bug report at https://github.com/tensorflow/tensorflow/issues "
-              "mentioning version: ",
-              TF_Version(), " and ", __FILE__, ":", __LINE__)
-              .c_str());
-      return nullptr;
-    }
-    return new TFE_TensorHandle(dst, nullptr);
-  }
   if (src_cpu) {
     tensorflow::Tensor dst(
         dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
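One consequence of the hunk above, sketched under stated assumptions (an existing context ctx, a CPU-resident handle hcpu, and cpu_device_name holding the full name of the CPU device; the helper name is hypothetical): a CPU-to-CPU copy previously round-tripped through a TensorProto to guarantee an independent buffer, whereas it is now treated as a same-device copy and the result aliases the source.

#include "tensorflow/c/eager/c_api.h"

// Sketch: copying a CPU handle back onto the CPU device no longer deep-copies.
void CopyOnCPU(TFE_Context* ctx, TFE_TensorHandle* hcpu,
               const char* cpu_device_name) {
  TF_Status* status = TF_NewStatus();
  TFE_TensorHandle* hcpu2 =
      TFE_TensorHandleCopyToDevice(hcpu, ctx, cpu_device_name, status);
  if (TF_GetCode(status) == TF_OK) {
    // With this commit, hcpu2 shares the buffer of hcpu (same device), so the
    // two handles refer to the same memory rather than an independent copy.
    TFE_DeleteTensorHandle(hcpu2);
  }
  TF_DeleteStatus(status);
}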
@@ -54,10 +54,10 @@ extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
 
 // Create a new TFE_TensorHandle with the same contents as 'h' but placed
 // in the memory of the device name 'device_name'.
-//
-// Currently requires at least one of the source or destination devices to
-// be CPU (i.e., for the source or destination tensor to be placed in
-// host memory).
+// If source and destination are the same device, then this creates a new handle
+// that shares the underlying buffer. Otherwise, it currently requires at least
+// one of the source or destination devices to be CPU (i.e., for the source or
+// destination tensor to be placed in host memory).
 extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                                       TFE_Context* ctx,
                                                       const char* device_name,
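A sketch of the error path that the updated comment still documents, under stated assumptions (an existing ctx, a handle hgpu0 resident on one GPU, and gpu1_name naming a different non-CPU device; the helper name is hypothetical): copying directly between two distinct non-CPU devices is still rejected.

#include <stdio.h>

#include "tensorflow/c/eager/c_api.h"

// Sketch: a copy between two different non-CPU devices is expected to fail
// with TF_INVALID_ARGUMENT and return NULL.
void CopyBetweenGPUs(TFE_Context* ctx, TFE_TensorHandle* hgpu0,
                     const char* gpu1_name) {
  TF_Status* status = TF_NewStatus();
  TFE_TensorHandle* moved =
      TFE_TensorHandleCopyToDevice(hgpu0, ctx, gpu1_name, status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "%s\n", TF_Message(status));
  } else {
    TFE_DeleteTensorHandle(moved);
  }
  TF_DeleteStatus(status);
}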
@@ -170,14 +170,22 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
       ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
       continue;
     }
-    // Copy back to CPU
-    TFE_TensorHandle* hcopy =
-        TFE_TensorHandleCopyToDevice(hdevice, ctx, kCPUDevice, status.get());
+    // Copy from device to the same device.
+    TFE_TensorHandle* hdevice2 =
+        TFE_TensorHandleCopyToDevice(hdevice, ctx, name.c_str(), status.get());
     if (TF_GetCode(status.get()) != TF_OK) {
       ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
       continue;
     }
     TFE_DeleteTensorHandle(hdevice);
+    // Copy back to CPU
+    TFE_TensorHandle* hcopy =
+        TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get());
+    if (TF_GetCode(status.get()) != TF_OK) {
+      ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
+      continue;
+    }
+    TFE_DeleteTensorHandle(hdevice2);
 
     // Ensure that the contents are the same!
     TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
@@ -150,21 +150,15 @@ class TFETest(test_util.TensorFlowTestCase):
     if not context.context().num_gpus():
       self.skipTest('No GPUs found')
 
-    cpu = tensor.Tensor([[1., 2.], [3., 4.]])
-    c2g = cpu.as_gpu_tensor()
-    # Exercise a copy from GPU to CPU, even though we ignore the value.
-    _ = c2g.as_cpu_tensor()
-
-    with self.assertRaises(errors.InvalidArgumentError):
-      # c2g is on GPU. Copying between GPU devices fails
-      # (must redirect through CPU for now).
-      # TODO(ashankar): Perhaps the function should not fail and instead
-      # faciliate the copy through host memory?
-      c2g.as_gpu_tensor()
+    x = tensor.Tensor([[1., 2.], [3., 4.]])
+    x = x.as_cpu_tensor()
+    x = x.as_gpu_tensor()
+    x = x.as_gpu_tensor()
+    x = x.as_cpu_tensor()
 
     # Invalid device
     with self.assertRaises(errors.InvalidArgumentError):
-      cpu.as_gpu_tensor(context.context().num_gpus() + 1)
+      x.as_gpu_tensor(context.context().num_gpus() + 1)
 
   def testNumpyForceCPU(self):
     if not context.context().num_gpus():
@@ -274,7 +268,8 @@ class TFETest(test_util.TensorFlowTestCase):
     product = execute.execute(
         'MatMul',
         num_outputs=1,
-        inputs=[tensor.Tensor([[3]]), tensor.Tensor([[5]])],
+        inputs=[tensor.Tensor([[3]]),
+                tensor.Tensor([[5]])],
         attrs=('transpose_a', True, 'transpose_b', False, 'T',
                dtypes.int32.as_datatype_enum))[0]
     self.assertEqual([[15]], product.numpy())
@@ -475,8 +470,8 @@ class TFETest(test_util.TensorFlowTestCase):
     with context.device('gpu:0'):
       y = truncated_normal(shape)
     # Add would fail if x and y were not on the same device.
-    execute.execute('Add', 1, inputs=[x, y],
-                    attrs=('T', x.dtype.as_datatype_enum))
+    execute.execute(
+        'Add', 1, inputs=[x, y], attrs=('T', x.dtype.as_datatype_enum))
 
   def testInvalidDevice(self):
     with self.assertRaises(ValueError):