[NFC] Provide synchronous versions of CopyDeviceTensorToCPU and CopyCPUTensorToDevice
PiperOrigin-RevId: 307091950 Change-Id: I21eb8e775954d8227ae909180689e49b7a039149
This commit is contained in:
parent
e5c6881c77
commit
18809492b3
@ -145,16 +145,9 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
attrs.set_on_host(true);
|
||||
TF_RETURN_IF_ERROR(ctx->allocate_temp(
|
||||
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
|
||||
Notification n;
|
||||
Status status;
|
||||
ctx->op_device_context()->CopyDeviceTensorToCPU(
|
||||
Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
|
||||
&device_tensor, "ConstantArgument",
|
||||
reinterpret_cast<Device*>(ctx->device()), &host_tensor,
|
||||
[&](Status s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
reinterpret_cast<Device*>(ctx->device()), &host_tensor);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Copying tensor of shape "
|
||||
<< device_tensor.shape().DebugString() << " from "
|
||||
|
@ -488,15 +488,8 @@ Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context,
|
||||
mutex_lock lock(mu_);
|
||||
Allocator* allocator = GetAllocatorLocked(alloc_attrs);
|
||||
Tensor copy(allocator, parsed.dtype(), parsed.shape());
|
||||
Notification n;
|
||||
device_context->CopyCPUTensorToDevice(
|
||||
&parsed, this, ©,
|
||||
[&n, &status](const Status& s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
},
|
||||
true /*sync_dst_compute*/);
|
||||
n.WaitForNotification();
|
||||
TF_RETURN_IF_ERROR(
|
||||
device_context->CopyCPUTensorToDeviceSync(&parsed, this, ©));
|
||||
*tensor = copy;
|
||||
}
|
||||
VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
|
||||
|
@ -82,9 +82,8 @@ class UnaryOpsCompositionTest : public OpsTestBase {
|
||||
DeviceContext* device_context =
|
||||
device_->tensorflow_gpu_device_info()->default_context;
|
||||
|
||||
TF_CHECK_OK(BlockingCopy([&](StatusCallback cb) {
|
||||
device_context->CopyCPUTensorToDevice(&input_on_host, device_, input, cb);
|
||||
}));
|
||||
TF_CHECK_OK(device_context->CopyCPUTensorToDeviceSync(&input_on_host,
|
||||
device_, input));
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
@ -94,27 +93,12 @@ class UnaryOpsCompositionTest : public OpsTestBase {
|
||||
Tensor* output = GetOutput(0);
|
||||
Tensor output_on_host(cpu_allocator, output->dtype(), output->shape());
|
||||
|
||||
TF_CHECK_OK(BlockingCopy([&](StatusCallback cb) {
|
||||
device_context->CopyDeviceTensorToCPU(output, "output 0", device_,
|
||||
&output_on_host, cb);
|
||||
}));
|
||||
TF_CHECK_OK(device_context->CopyDeviceTensorToCPUSync(
|
||||
output, "output 0", device_, &output_on_host));
|
||||
|
||||
test::ExpectClose(expected_tensor, output_on_host, /*atol=*/1e-5,
|
||||
/*rtol=*/1e-5);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename CopyFnTy>
|
||||
Status BlockingCopy(CopyFnTy copy_fn) {
|
||||
Notification n;
|
||||
Status status;
|
||||
copy_fn([&](Status s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sqrt_F) {
|
||||
|
@ -2295,6 +2295,7 @@ tf_cuda_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/time",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core/example:feature_util",
|
||||
|
@ -94,25 +94,14 @@ class GPUDeviceTest : public ::testing::Test {
|
||||
|
||||
void CopyCPUToGPU(Tensor* cpu_tensor, Tensor* gpu_tensor, Device* device,
|
||||
DeviceContext* device_context) {
|
||||
Notification note;
|
||||
device_context->CopyCPUTensorToDevice(cpu_tensor, device, gpu_tensor,
|
||||
[¬e](const Status& s) {
|
||||
TF_ASSERT_OK(s);
|
||||
note.Notify();
|
||||
});
|
||||
note.WaitForNotification();
|
||||
TF_ASSERT_OK(device_context->CopyCPUTensorToDeviceSync(cpu_tensor, device,
|
||||
gpu_tensor));
|
||||
}
|
||||
|
||||
void CopyGPUToCPU(Tensor* gpu_tensor, Tensor* cpu_tensor, Device* device,
|
||||
DeviceContext* device_context) {
|
||||
Notification note;
|
||||
device_context->CopyDeviceTensorToCPU(gpu_tensor, /*tensor_name=*/"",
|
||||
device, cpu_tensor,
|
||||
[¬e](const Status& s) {
|
||||
TF_ASSERT_OK(s);
|
||||
note.Notify();
|
||||
});
|
||||
note.WaitForNotification();
|
||||
TF_ASSERT_OK(device_context->CopyDeviceTensorToCPUSync(
|
||||
gpu_tensor, /*tensor_name=*/"", device, cpu_tensor));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -442,13 +442,8 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
|
||||
Device* dev = instances_[broadcast_dev_id]->device_;
|
||||
auto* dev_info = dev->tensorflow_gpu_device_info();
|
||||
CHECK(dev_info);
|
||||
dev_info->default_context->CopyDeviceTensorToCPU(
|
||||
t, "" /*tensor_name*/, dev, &cpu_copy,
|
||||
[this, ¬ification](Status s) {
|
||||
TF_CHECK_OK(s);
|
||||
notification.Notify();
|
||||
});
|
||||
notification.WaitForNotification();
|
||||
TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
|
||||
t, "" /*tensor_name*/, dev, &cpu_copy));
|
||||
t = &cpu_copy;
|
||||
}
|
||||
for (size_t i = 0; i < t->NumElements(); ++i) {
|
||||
@ -473,17 +468,11 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
|
||||
if (device_type_ == DEVICE_CPU) {
|
||||
CHECK(actual.CopyFrom(*inst, inst->shape()));
|
||||
} else if (device_type_ == DEVICE_GPU) {
|
||||
Notification notification;
|
||||
Device* dev = instances_[di]->device_;
|
||||
auto* dev_info = dev->tensorflow_gpu_device_info();
|
||||
CHECK(dev_info);
|
||||
dev_info->default_context->CopyDeviceTensorToCPU(
|
||||
inst, "" /*tensor_name*/, dev, &actual,
|
||||
[this, ¬ification](Status s) {
|
||||
TF_CHECK_OK(s);
|
||||
notification.Notify();
|
||||
});
|
||||
notification.WaitForNotification();
|
||||
TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
|
||||
inst, "" /*tensor_name*/, dev, &actual));
|
||||
}
|
||||
for (int i = 0; i < tensor_len; ++i) {
|
||||
switch (dtype) {
|
||||
@ -623,12 +612,8 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
|
||||
Notification notification;
|
||||
auto* dev_info = device_->tensorflow_gpu_device_info();
|
||||
CHECK(dev_info);
|
||||
dev_info->default_context->CopyCPUTensorToDevice(
|
||||
&cpu_tensor, device_, &tensor_, [¬ification](Status s) {
|
||||
TF_CHECK_OK(s);
|
||||
notification.Notify();
|
||||
});
|
||||
notification.WaitForNotification();
|
||||
TF_CHECK_OK(dev_info->default_context->CopyCPUTensorToDeviceSync(
|
||||
&cpu_tensor, device_, &tensor_));
|
||||
} else {
|
||||
LOG(FATAL) << "Unsupported device_type " << device_type_;
|
||||
}
|
||||
|
@ -157,17 +157,11 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
DeviceContext* device_context =
|
||||
gpu_device_->tensorflow_gpu_device_info()->default_context;
|
||||
|
||||
Notification n;
|
||||
Status status;
|
||||
Tensor cpu_tensor(device_tensor.dtype(), device_tensor.shape());
|
||||
device_context->CopyDeviceTensorToCPU(&device_tensor, "", gpu_device_,
|
||||
&cpu_tensor,
|
||||
[&n, &status](const Status& s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
CHECK(status.ok());
|
||||
CHECK(device_context
|
||||
->CopyDeviceTensorToCPUSync(&device_tensor, "", gpu_device_,
|
||||
&cpu_tensor)
|
||||
.ok());
|
||||
return cpu_tensor;
|
||||
#else
|
||||
CHECK(false);
|
||||
@ -181,18 +175,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
DeviceContext* device_context =
|
||||
gpu_device_->tensorflow_gpu_device_info()->default_context;
|
||||
|
||||
Notification n;
|
||||
Status status;
|
||||
Tensor device_tensor(gpu_device_->GetAllocator({}), cpu_tensor.dtype(),
|
||||
cpu_tensor.shape(), {});
|
||||
device_context->CopyCPUTensorToDevice(&cpu_tensor, gpu_device_,
|
||||
&device_tensor,
|
||||
[&n, &status](const Status& s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
CHECK(status.ok());
|
||||
CHECK(device_context
|
||||
->CopyCPUTensorToDeviceSync(&cpu_tensor, gpu_device_,
|
||||
&device_tensor)
|
||||
.ok());
|
||||
return device_tensor;
|
||||
#else
|
||||
CHECK(false);
|
||||
|
@ -254,14 +254,9 @@ string RingAlg::TensorDebugString(const Tensor& tensor) {
|
||||
col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
|
||||
if (gpu_device_info) {
|
||||
Tensor cpu_tensor(tensor.dtype(), tensor.shape());
|
||||
Notification note;
|
||||
gpu_device_info->default_context->CopyDeviceTensorToCPU(
|
||||
&tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor,
|
||||
[¬e](const Status& s) {
|
||||
DCHECK(s.ok());
|
||||
note.Notify();
|
||||
});
|
||||
note.WaitForNotification();
|
||||
Status st = gpu_device_info->default_context->CopyDeviceTensorToCPUSync(
|
||||
&tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor);
|
||||
DCHECK(st.ok());
|
||||
return cpu_tensor.SummarizeValue(64);
|
||||
} else {
|
||||
return tensor.SummarizeValue(64);
|
||||
|
@ -310,16 +310,13 @@ class RingGathererTest : public ::testing::Test {
|
||||
CHECK(actual.CopyFrom(*inst, inst->shape()));
|
||||
VLOG(1) << "actual " << actual.SummarizeValue(100);
|
||||
} else if (device_type_ == DEVICE_GPU) {
|
||||
Notification note;
|
||||
Device* dev = instances_[di]->device_;
|
||||
auto* dev_info = dev->tensorflow_gpu_device_info();
|
||||
CHECK(dev_info);
|
||||
dev_info->default_context->CopyDeviceTensorToCPU(
|
||||
inst, "" /*tensor_name*/, dev, &actual, [¬e](const Status& s) {
|
||||
CHECK(s.ok());
|
||||
note.Notify();
|
||||
});
|
||||
note.WaitForNotification();
|
||||
CHECK(dev_info->default_context
|
||||
->CopyDeviceTensorToCPUSync(inst, "" /*tensor_name*/, dev,
|
||||
&actual)
|
||||
.ok());
|
||||
}
|
||||
|
||||
auto alias = actual.template unaligned_flat<T>();
|
||||
@ -433,13 +430,10 @@ class RingGathererTest : public ::testing::Test {
|
||||
init_f(&cpu_tensor);
|
||||
auto* dev_info = device_->tensorflow_gpu_device_info();
|
||||
CHECK(dev_info);
|
||||
Notification note;
|
||||
dev_info->default_context->CopyCPUTensorToDevice(
|
||||
&cpu_tensor, device_, &input_tensor_, [¬e](const Status& s) {
|
||||
CHECK(s.ok());
|
||||
note.Notify();
|
||||
});
|
||||
note.WaitForNotification();
|
||||
CHECK(dev_info->default_context
|
||||
->CopyCPUTensorToDeviceSync(&cpu_tensor, device_,
|
||||
&input_tensor_)
|
||||
.ok());
|
||||
} else {
|
||||
LOG(FATAL) << "Unsupported device_type " << device_type_;
|
||||
}
|
||||
|
@ -331,16 +331,13 @@ class RingReducerTest : public ::testing::Test {
|
||||
CHECK(actual.CopyFrom(*inst, inst->shape()));
|
||||
VLOG(1) << "actual " << actual.SummarizeValue(100);
|
||||
} else if (device_type_ == DEVICE_GPU) {
|
||||
Notification note;
|
||||
Device* dev = instances_[di]->device_;
|
||||
auto* dev_info = dev->tensorflow_gpu_device_info();
|
||||
CHECK(dev_info);
|
||||
dev_info->default_context->CopyDeviceTensorToCPU(
|
||||
inst, "" /*tensor_name*/, dev, &actual, [¬e](const Status& s) {
|
||||
CHECK(s.ok());
|
||||
note.Notify();
|
||||
});
|
||||
note.WaitForNotification();
|
||||
CHECK(dev_info->default_context
|
||||
->CopyDeviceTensorToCPUSync(inst, "" /*tensor_name*/, dev,
|
||||
&actual)
|
||||
.ok());
|
||||
}
|
||||
|
||||
auto alias = actual.template unaligned_flat<T>();
|
||||
@ -458,13 +455,9 @@ class RingReducerTest : public ::testing::Test {
|
||||
init_f(&cpu_tensor);
|
||||
auto* dev_info = device_->tensorflow_gpu_device_info();
|
||||
CHECK(dev_info);
|
||||
Notification note;
|
||||
dev_info->default_context->CopyCPUTensorToDevice(
|
||||
&cpu_tensor, device_, &tensor_, [¬e](const Status& s) {
|
||||
CHECK(s.ok());
|
||||
note.Notify();
|
||||
});
|
||||
note.WaitForNotification();
|
||||
CHECK(dev_info->default_context
|
||||
->CopyCPUTensorToDeviceSync(&cpu_tensor, device_, &tensor_)
|
||||
.ok());
|
||||
} else {
|
||||
LOG(FATAL) << "Unsupported device_type " << device_type_;
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
@ -33,6 +34,35 @@ DeviceBase::~DeviceBase() {
|
||||
eigen_cpu_devices_.clear();
|
||||
}
|
||||
|
||||
Status DeviceContext::CopyDeviceTensorToCPUSync(const Tensor* device_tensor,
|
||||
StringPiece tensor_name,
|
||||
Device* device,
|
||||
Tensor* cpu_tensor) {
|
||||
absl::Notification n;
|
||||
Status status;
|
||||
CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
|
||||
[&](const Status& s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return status;
|
||||
}
|
||||
|
||||
Status DeviceContext::CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor,
|
||||
Device* device,
|
||||
Tensor* device_tensor) const {
|
||||
absl::Notification n;
|
||||
Status status;
|
||||
CopyCPUTensorToDevice(cpu_tensor, device, device_tensor,
|
||||
[&](const Status& s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return status;
|
||||
}
|
||||
|
||||
const DeviceAttributes& DeviceBase::attributes() const {
|
||||
LOG(FATAL) << "Device does not implement attributes()";
|
||||
}
|
||||
|
@ -84,6 +84,10 @@ class DeviceContext : public core::RefCounted {
|
||||
done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
|
||||
}
|
||||
|
||||
// Same as CopyCPUTensorToDevice, but in a synchronous way.
|
||||
Status CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor, Device* device,
|
||||
Tensor* device_tensor) const;
|
||||
|
||||
// Copies a tensor in this device.
|
||||
virtual void CopyTensorInSameDevice(const Tensor* input_tensor,
|
||||
Device* device, Tensor* output_tensor,
|
||||
@ -100,6 +104,11 @@ class DeviceContext : public core::RefCounted {
|
||||
done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
|
||||
}
|
||||
|
||||
// Same as `CopyDeviceTensorToCPU`, but blocks until the copy is done.
|
||||
Status CopyDeviceTensorToCPUSync(const Tensor* device_tensor,
|
||||
StringPiece tensor_name, Device* device,
|
||||
Tensor* cpu_tensor);
|
||||
|
||||
// If possible, wait for all events on *stream to complete then execute func.
|
||||
// A non-OK Status is returned otherwise. The stream argument should be the
|
||||
// one provided by GpuDeviceInfo. This function is not applicable to devices
|
||||
|
@ -205,15 +205,10 @@ class NcclTestBase : public ::testing::Test {
|
||||
VLOG(2) << "rank " << rank << " output " << output << " buf "
|
||||
<< DMAHelper::base(output);
|
||||
Tensor actual(DT_FLOAT, TensorShape({output_length}));
|
||||
Notification note;
|
||||
Device* dev = instances_[rank]->device_;
|
||||
auto* dev_info = dev->tensorflow_gpu_device_info();
|
||||
dev_info->default_context->CopyDeviceTensorToCPU(
|
||||
output, /*tensor_name=*/"", dev, &actual, [¬e](const Status& s) {
|
||||
TF_CHECK_OK(s);
|
||||
note.Notify();
|
||||
});
|
||||
note.WaitForNotification();
|
||||
TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
|
||||
output, /*tensor_name=*/"", dev, &actual));
|
||||
VLOG(3) << "rank " << rank << " got output tensor "
|
||||
<< actual.DebugString(output_length);
|
||||
for (int i = 0; i < output_length; ++i) {
|
||||
@ -270,13 +265,8 @@ class NcclTestBase : public ::testing::Test {
|
||||
VLOG(2) << "input tensor " << cpu_tensor.DebugString();
|
||||
}
|
||||
auto* dev_info = device_->tensorflow_gpu_device_info();
|
||||
Notification note;
|
||||
dev_info->default_context->CopyCPUTensorToDevice(
|
||||
&cpu_tensor, device_, &input_, [¬e](const Status& s) {
|
||||
TF_CHECK_OK(s);
|
||||
note.Notify();
|
||||
});
|
||||
note.WaitForNotification();
|
||||
TF_CHECK_OK(dev_info->default_context->CopyCPUTensorToDeviceSync(
|
||||
&cpu_tensor, device_, &input_));
|
||||
}
|
||||
|
||||
void PrepareDeviceContext(OpKernelContext::Params* params) {
|
||||
|
@ -447,14 +447,8 @@ class WhileOp : public AsyncOpKernel {
|
||||
Device* device = down_cast<Device*>(ctx_->device());
|
||||
DeviceContext* device_ctx = ctx_->op_device_context();
|
||||
cond_t = Tensor(rets_[0].dtype(), rets_[0].shape());
|
||||
Notification done_copy;
|
||||
device_ctx->CopyDeviceTensorToCPU(
|
||||
&rets_[0], /*tensor_name=*/"", device, &cond_t,
|
||||
[&done_copy, &s](const Status& status) {
|
||||
s = status;
|
||||
done_copy.Notify();
|
||||
});
|
||||
done_copy.WaitForNotification();
|
||||
s = device_ctx->CopyDeviceTensorToCPUSync(&rets_[0], /*tensor_name=*/"",
|
||||
device, &cond_t);
|
||||
if (!s.ok()) {
|
||||
return Finish(s);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user