[NFC] Provide synchronous versions of CopyDeviceTensorToCPU and CopyCPUTensorToDevice

PiperOrigin-RevId: 307091950
Change-Id: I21eb8e775954d8227ae909180689e49b7a039149
This commit is contained in:
George Karpenkov 2020-04-17 12:24:00 -07:00 committed by TensorFlower Gardener
parent e5c6881c77
commit 18809492b3
14 changed files with 90 additions and 152 deletions

View File

@@ -145,16 +145,9 @@ Status XlaCompileOnDemandOp::Compile(
attrs.set_on_host(true); attrs.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_temp( TF_RETURN_IF_ERROR(ctx->allocate_temp(
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
Notification n; Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
Status status;
ctx->op_device_context()->CopyDeviceTensorToCPU(
&device_tensor, "ConstantArgument", &device_tensor, "ConstantArgument",
reinterpret_cast<Device*>(ctx->device()), &host_tensor, reinterpret_cast<Device*>(ctx->device()), &host_tensor);
[&](Status s) {
status = s;
n.Notify();
});
n.WaitForNotification();
if (!status.ok()) { if (!status.ok()) {
LOG(ERROR) << "Copying tensor of shape " LOG(ERROR) << "Copying tensor of shape "
<< device_tensor.shape().DebugString() << " from " << device_tensor.shape().DebugString() << " from "

View File

@@ -488,15 +488,8 @@ Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context,
mutex_lock lock(mu_); mutex_lock lock(mu_);
Allocator* allocator = GetAllocatorLocked(alloc_attrs); Allocator* allocator = GetAllocatorLocked(alloc_attrs);
Tensor copy(allocator, parsed.dtype(), parsed.shape()); Tensor copy(allocator, parsed.dtype(), parsed.shape());
Notification n; TF_RETURN_IF_ERROR(
device_context->CopyCPUTensorToDevice( device_context->CopyCPUTensorToDeviceSync(&parsed, this, &copy));
&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;
n.Notify();
},
true /*sync_dst_compute*/);
n.WaitForNotification();
*tensor = copy; *tensor = copy;
} }
VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor); VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);

View File

@@ -82,9 +82,8 @@ class UnaryOpsCompositionTest : public OpsTestBase {
DeviceContext* device_context = DeviceContext* device_context =
device_->tensorflow_gpu_device_info()->default_context; device_->tensorflow_gpu_device_info()->default_context;
TF_CHECK_OK(BlockingCopy([&](StatusCallback cb) { TF_CHECK_OK(device_context->CopyCPUTensorToDeviceSync(&input_on_host,
device_context->CopyCPUTensorToDevice(&input_on_host, device_, input, cb); device_, input));
}));
TF_ASSERT_OK(RunOpKernel()); TF_ASSERT_OK(RunOpKernel());
@@ -94,27 +93,12 @@ class UnaryOpsCompositionTest : public OpsTestBase {
Tensor* output = GetOutput(0); Tensor* output = GetOutput(0);
Tensor output_on_host(cpu_allocator, output->dtype(), output->shape()); Tensor output_on_host(cpu_allocator, output->dtype(), output->shape());
TF_CHECK_OK(BlockingCopy([&](StatusCallback cb) { TF_CHECK_OK(device_context->CopyDeviceTensorToCPUSync(
device_context->CopyDeviceTensorToCPU(output, "output 0", device_, output, "output 0", device_, &output_on_host));
&output_on_host, cb);
}));
test::ExpectClose(expected_tensor, output_on_host, /*atol=*/1e-5, test::ExpectClose(expected_tensor, output_on_host, /*atol=*/1e-5,
/*rtol=*/1e-5); /*rtol=*/1e-5);
} }
private:
template <typename CopyFnTy>
Status BlockingCopy(CopyFnTy copy_fn) {
Notification n;
Status status;
copy_fn([&](Status s) {
status = s;
n.Notify();
});
n.WaitForNotification();
return status;
}
}; };
TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sqrt_F) { TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sqrt_F) {

View File

@@ -2295,6 +2295,7 @@ tf_cuda_library(
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time", "@com_google_absl//absl/time",
"//third_party/eigen3", "//third_party/eigen3",
"//tensorflow/core/example:feature_util", "//tensorflow/core/example:feature_util",

View File

@@ -94,25 +94,14 @@ class GPUDeviceTest : public ::testing::Test {
void CopyCPUToGPU(Tensor* cpu_tensor, Tensor* gpu_tensor, Device* device, void CopyCPUToGPU(Tensor* cpu_tensor, Tensor* gpu_tensor, Device* device,
DeviceContext* device_context) { DeviceContext* device_context) {
Notification note; TF_ASSERT_OK(device_context->CopyCPUTensorToDeviceSync(cpu_tensor, device,
device_context->CopyCPUTensorToDevice(cpu_tensor, device, gpu_tensor, gpu_tensor));
[&note](const Status& s) {
TF_ASSERT_OK(s);
note.Notify();
});
note.WaitForNotification();
} }
void CopyGPUToCPU(Tensor* gpu_tensor, Tensor* cpu_tensor, Device* device, void CopyGPUToCPU(Tensor* gpu_tensor, Tensor* cpu_tensor, Device* device,
DeviceContext* device_context) { DeviceContext* device_context) {
Notification note; TF_ASSERT_OK(device_context->CopyDeviceTensorToCPUSync(
device_context->CopyDeviceTensorToCPU(gpu_tensor, /*tensor_name=*/"", gpu_tensor, /*tensor_name=*/"", device, cpu_tensor));
device, cpu_tensor,
[&note](const Status& s) {
TF_ASSERT_OK(s);
note.Notify();
});
note.WaitForNotification();
} }
}; };

View File

@@ -442,13 +442,8 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
Device* dev = instances_[broadcast_dev_id]->device_; Device* dev = instances_[broadcast_dev_id]->device_;
auto* dev_info = dev->tensorflow_gpu_device_info(); auto* dev_info = dev->tensorflow_gpu_device_info();
CHECK(dev_info); CHECK(dev_info);
dev_info->default_context->CopyDeviceTensorToCPU( TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
t, "" /*tensor_name*/, dev, &cpu_copy, t, "" /*tensor_name*/, dev, &cpu_copy));
[this, &notification](Status s) {
TF_CHECK_OK(s);
notification.Notify();
});
notification.WaitForNotification();
t = &cpu_copy; t = &cpu_copy;
} }
for (size_t i = 0; i < t->NumElements(); ++i) { for (size_t i = 0; i < t->NumElements(); ++i) {
@@ -473,17 +468,11 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
if (device_type_ == DEVICE_CPU) { if (device_type_ == DEVICE_CPU) {
CHECK(actual.CopyFrom(*inst, inst->shape())); CHECK(actual.CopyFrom(*inst, inst->shape()));
} else if (device_type_ == DEVICE_GPU) { } else if (device_type_ == DEVICE_GPU) {
Notification notification;
Device* dev = instances_[di]->device_; Device* dev = instances_[di]->device_;
auto* dev_info = dev->tensorflow_gpu_device_info(); auto* dev_info = dev->tensorflow_gpu_device_info();
CHECK(dev_info); CHECK(dev_info);
dev_info->default_context->CopyDeviceTensorToCPU( TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
inst, "" /*tensor_name*/, dev, &actual, inst, "" /*tensor_name*/, dev, &actual));
[this, &notification](Status s) {
TF_CHECK_OK(s);
notification.Notify();
});
notification.WaitForNotification();
} }
for (int i = 0; i < tensor_len; ++i) { for (int i = 0; i < tensor_len; ++i) {
switch (dtype) { switch (dtype) {
@@ -623,12 +612,8 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
Notification notification; Notification notification;
auto* dev_info = device_->tensorflow_gpu_device_info(); auto* dev_info = device_->tensorflow_gpu_device_info();
CHECK(dev_info); CHECK(dev_info);
dev_info->default_context->CopyCPUTensorToDevice( TF_CHECK_OK(dev_info->default_context->CopyCPUTensorToDeviceSync(
&cpu_tensor, device_, &tensor_, [&notification](Status s) { &cpu_tensor, device_, &tensor_));
TF_CHECK_OK(s);
notification.Notify();
});
notification.WaitForNotification();
} else { } else {
LOG(FATAL) << "Unsupported device_type " << device_type_; LOG(FATAL) << "Unsupported device_type " << device_type_;
} }

View File

@@ -157,17 +157,11 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
DeviceContext* device_context = DeviceContext* device_context =
gpu_device_->tensorflow_gpu_device_info()->default_context; gpu_device_->tensorflow_gpu_device_info()->default_context;
Notification n;
Status status;
Tensor cpu_tensor(device_tensor.dtype(), device_tensor.shape()); Tensor cpu_tensor(device_tensor.dtype(), device_tensor.shape());
device_context->CopyDeviceTensorToCPU(&device_tensor, "", gpu_device_, CHECK(device_context
&cpu_tensor, ->CopyDeviceTensorToCPUSync(&device_tensor, "", gpu_device_,
[&n, &status](const Status& s) { &cpu_tensor)
status = s; .ok());
n.Notify();
});
n.WaitForNotification();
CHECK(status.ok());
return cpu_tensor; return cpu_tensor;
#else #else
CHECK(false); CHECK(false);
@@ -181,18 +175,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
DeviceContext* device_context = DeviceContext* device_context =
gpu_device_->tensorflow_gpu_device_info()->default_context; gpu_device_->tensorflow_gpu_device_info()->default_context;
Notification n;
Status status;
Tensor device_tensor(gpu_device_->GetAllocator({}), cpu_tensor.dtype(), Tensor device_tensor(gpu_device_->GetAllocator({}), cpu_tensor.dtype(),
cpu_tensor.shape(), {}); cpu_tensor.shape(), {});
device_context->CopyCPUTensorToDevice(&cpu_tensor, gpu_device_, CHECK(device_context
&device_tensor, ->CopyCPUTensorToDeviceSync(&cpu_tensor, gpu_device_,
[&n, &status](const Status& s) { &device_tensor)
status = s; .ok());
n.Notify();
});
n.WaitForNotification();
CHECK(status.ok());
return device_tensor; return device_tensor;
#else #else
CHECK(false); CHECK(false);

View File

@@ -254,14 +254,9 @@ string RingAlg::TensorDebugString(const Tensor& tensor) {
col_ctx_->op_ctx->device()->tensorflow_gpu_device_info(); col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
if (gpu_device_info) { if (gpu_device_info) {
Tensor cpu_tensor(tensor.dtype(), tensor.shape()); Tensor cpu_tensor(tensor.dtype(), tensor.shape());
Notification note; Status st = gpu_device_info->default_context->CopyDeviceTensorToCPUSync(
gpu_device_info->default_context->CopyDeviceTensorToCPU( &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor);
&tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor, DCHECK(st.ok());
[&note](const Status& s) {
DCHECK(s.ok());
note.Notify();
});
note.WaitForNotification();
return cpu_tensor.SummarizeValue(64); return cpu_tensor.SummarizeValue(64);
} else { } else {
return tensor.SummarizeValue(64); return tensor.SummarizeValue(64);

View File

@@ -310,16 +310,13 @@ class RingGathererTest : public ::testing::Test {
CHECK(actual.CopyFrom(*inst, inst->shape())); CHECK(actual.CopyFrom(*inst, inst->shape()));
VLOG(1) << "actual " << actual.SummarizeValue(100); VLOG(1) << "actual " << actual.SummarizeValue(100);
} else if (device_type_ == DEVICE_GPU) { } else if (device_type_ == DEVICE_GPU) {
Notification note;
Device* dev = instances_[di]->device_; Device* dev = instances_[di]->device_;
auto* dev_info = dev->tensorflow_gpu_device_info(); auto* dev_info = dev->tensorflow_gpu_device_info();
CHECK(dev_info); CHECK(dev_info);
dev_info->default_context->CopyDeviceTensorToCPU( CHECK(dev_info->default_context
inst, "" /*tensor_name*/, dev, &actual, [&note](const Status& s) { ->CopyDeviceTensorToCPUSync(inst, "" /*tensor_name*/, dev,
CHECK(s.ok()); &actual)
note.Notify(); .ok());
});
note.WaitForNotification();
} }
auto alias = actual.template unaligned_flat<T>(); auto alias = actual.template unaligned_flat<T>();
@@ -433,13 +430,10 @@ class RingGathererTest : public ::testing::Test {
init_f(&cpu_tensor); init_f(&cpu_tensor);
auto* dev_info = device_->tensorflow_gpu_device_info(); auto* dev_info = device_->tensorflow_gpu_device_info();
CHECK(dev_info); CHECK(dev_info);
Notification note; CHECK(dev_info->default_context
dev_info->default_context->CopyCPUTensorToDevice( ->CopyCPUTensorToDeviceSync(&cpu_tensor, device_,
&cpu_tensor, device_, &input_tensor_, [&note](const Status& s) { &input_tensor_)
CHECK(s.ok()); .ok());
note.Notify();
});
note.WaitForNotification();
} else { } else {
LOG(FATAL) << "Unsupported device_type " << device_type_; LOG(FATAL) << "Unsupported device_type " << device_type_;
} }

View File

@@ -331,16 +331,13 @@ class RingReducerTest : public ::testing::Test {
CHECK(actual.CopyFrom(*inst, inst->shape())); CHECK(actual.CopyFrom(*inst, inst->shape()));
VLOG(1) << "actual " << actual.SummarizeValue(100); VLOG(1) << "actual " << actual.SummarizeValue(100);
} else if (device_type_ == DEVICE_GPU) { } else if (device_type_ == DEVICE_GPU) {
Notification note;
Device* dev = instances_[di]->device_; Device* dev = instances_[di]->device_;
auto* dev_info = dev->tensorflow_gpu_device_info(); auto* dev_info = dev->tensorflow_gpu_device_info();
CHECK(dev_info); CHECK(dev_info);
dev_info->default_context->CopyDeviceTensorToCPU( CHECK(dev_info->default_context
inst, "" /*tensor_name*/, dev, &actual, [&note](const Status& s) { ->CopyDeviceTensorToCPUSync(inst, "" /*tensor_name*/, dev,
CHECK(s.ok()); &actual)
note.Notify(); .ok());
});
note.WaitForNotification();
} }
auto alias = actual.template unaligned_flat<T>(); auto alias = actual.template unaligned_flat<T>();
@@ -458,13 +455,9 @@ class RingReducerTest : public ::testing::Test {
init_f(&cpu_tensor); init_f(&cpu_tensor);
auto* dev_info = device_->tensorflow_gpu_device_info(); auto* dev_info = device_->tensorflow_gpu_device_info();
CHECK(dev_info); CHECK(dev_info);
Notification note; CHECK(dev_info->default_context
dev_info->default_context->CopyCPUTensorToDevice( ->CopyCPUTensorToDeviceSync(&cpu_tensor, device_, &tensor_)
&cpu_tensor, device_, &tensor_, [&note](const Status& s) { .ok());
CHECK(s.ok());
note.Notify();
});
note.WaitForNotification();
} else { } else {
LOG(FATAL) << "Unsupported device_type " << device_type_; LOG(FATAL) << "Unsupported device_type " << device_type_;
} }

View File

@@ -21,6 +21,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/synchronization/notification.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/util/work_sharder.h" #include "tensorflow/core/util/work_sharder.h"
@@ -33,6 +34,35 @@ DeviceBase::~DeviceBase() {
eigen_cpu_devices_.clear(); eigen_cpu_devices_.clear();
} }
Status DeviceContext::CopyDeviceTensorToCPUSync(const Tensor* device_tensor,
StringPiece tensor_name,
Device* device,
Tensor* cpu_tensor) {
absl::Notification n;
Status status;
CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
[&](const Status& s) {
status = s;
n.Notify();
});
n.WaitForNotification();
return status;
}
Status DeviceContext::CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor) const {
absl::Notification n;
Status status;
CopyCPUTensorToDevice(cpu_tensor, device, device_tensor,
[&](const Status& s) {
status = s;
n.Notify();
});
n.WaitForNotification();
return status;
}
const DeviceAttributes& DeviceBase::attributes() const { const DeviceAttributes& DeviceBase::attributes() const {
LOG(FATAL) << "Device does not implement attributes()"; LOG(FATAL) << "Device does not implement attributes()";
} }

View File

@@ -84,6 +84,10 @@ class DeviceContext : public core::RefCounted {
done(errors::Internal("Unrecognized device type in CPU-to-device Copy")); done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
} }
// Same as CopyCPUTensorToDevice, but in a synchronous way.
Status CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor) const;
// Copies a tensor in this device. // Copies a tensor in this device.
virtual void CopyTensorInSameDevice(const Tensor* input_tensor, virtual void CopyTensorInSameDevice(const Tensor* input_tensor,
Device* device, Tensor* output_tensor, Device* device, Tensor* output_tensor,
@@ -100,6 +104,11 @@ class DeviceContext : public core::RefCounted {
done(errors::Internal("Unrecognized device type in device-to-CPU Copy")); done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
} }
// Same as `CopyDeviceTensorToCPU`, but blocks until the copy is done.
Status CopyDeviceTensorToCPUSync(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor);
// If possible, wait for all events on *stream to complete then execute func. // If possible, wait for all events on *stream to complete then execute func.
// A non-OK Status is returned otherwise. The stream argument should be the // A non-OK Status is returned otherwise. The stream argument should be the
// one provided by GpuDeviceInfo. This function is not applicable to devices // one provided by GpuDeviceInfo. This function is not applicable to devices

View File

@@ -205,15 +205,10 @@ class NcclTestBase : public ::testing::Test {
VLOG(2) << "rank " << rank << " output " << output << " buf " VLOG(2) << "rank " << rank << " output " << output << " buf "
<< DMAHelper::base(output); << DMAHelper::base(output);
Tensor actual(DT_FLOAT, TensorShape({output_length})); Tensor actual(DT_FLOAT, TensorShape({output_length}));
Notification note;
Device* dev = instances_[rank]->device_; Device* dev = instances_[rank]->device_;
auto* dev_info = dev->tensorflow_gpu_device_info(); auto* dev_info = dev->tensorflow_gpu_device_info();
dev_info->default_context->CopyDeviceTensorToCPU( TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
output, /*tensor_name=*/"", dev, &actual, [&note](const Status& s) { output, /*tensor_name=*/"", dev, &actual));
TF_CHECK_OK(s);
note.Notify();
});
note.WaitForNotification();
VLOG(3) << "rank " << rank << " got output tensor " VLOG(3) << "rank " << rank << " got output tensor "
<< actual.DebugString(output_length); << actual.DebugString(output_length);
for (int i = 0; i < output_length; ++i) { for (int i = 0; i < output_length; ++i) {
@@ -270,13 +265,8 @@ class NcclTestBase : public ::testing::Test {
VLOG(2) << "input tensor " << cpu_tensor.DebugString(); VLOG(2) << "input tensor " << cpu_tensor.DebugString();
} }
auto* dev_info = device_->tensorflow_gpu_device_info(); auto* dev_info = device_->tensorflow_gpu_device_info();
Notification note; TF_CHECK_OK(dev_info->default_context->CopyCPUTensorToDeviceSync(
dev_info->default_context->CopyCPUTensorToDevice( &cpu_tensor, device_, &input_));
&cpu_tensor, device_, &input_, [&note](const Status& s) {
TF_CHECK_OK(s);
note.Notify();
});
note.WaitForNotification();
} }
void PrepareDeviceContext(OpKernelContext::Params* params) { void PrepareDeviceContext(OpKernelContext::Params* params) {

View File

@@ -447,14 +447,8 @@ class WhileOp : public AsyncOpKernel {
Device* device = down_cast<Device*>(ctx_->device()); Device* device = down_cast<Device*>(ctx_->device());
DeviceContext* device_ctx = ctx_->op_device_context(); DeviceContext* device_ctx = ctx_->op_device_context();
cond_t = Tensor(rets_[0].dtype(), rets_[0].shape()); cond_t = Tensor(rets_[0].dtype(), rets_[0].shape());
Notification done_copy; s = device_ctx->CopyDeviceTensorToCPUSync(&rets_[0], /*tensor_name=*/"",
device_ctx->CopyDeviceTensorToCPU( device, &cond_t);
&rets_[0], /*tensor_name=*/"", device, &cond_t,
[&done_copy, &s](const Status& status) {
s = status;
done_copy.Notify();
});
done_copy.WaitForNotification();
if (!s.ok()) { if (!s.ok()) {
return Finish(s); return Finish(s);
} }