[NFC] Provide synchronous versions of CopyDeviceTensorToCPU and CopyCPUTensorToDevice
PiperOrigin-RevId: 307091950 Change-Id: I21eb8e775954d8227ae909180689e49b7a039149
This commit is contained in:
parent
e5c6881c77
commit
18809492b3
@ -145,16 +145,9 @@ Status XlaCompileOnDemandOp::Compile(
|
|||||||
attrs.set_on_host(true);
|
attrs.set_on_host(true);
|
||||||
TF_RETURN_IF_ERROR(ctx->allocate_temp(
|
TF_RETURN_IF_ERROR(ctx->allocate_temp(
|
||||||
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
|
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
|
||||||
Notification n;
|
Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
|
||||||
Status status;
|
|
||||||
ctx->op_device_context()->CopyDeviceTensorToCPU(
|
|
||||||
&device_tensor, "ConstantArgument",
|
&device_tensor, "ConstantArgument",
|
||||||
reinterpret_cast<Device*>(ctx->device()), &host_tensor,
|
reinterpret_cast<Device*>(ctx->device()), &host_tensor);
|
||||||
[&](Status s) {
|
|
||||||
status = s;
|
|
||||||
n.Notify();
|
|
||||||
});
|
|
||||||
n.WaitForNotification();
|
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
LOG(ERROR) << "Copying tensor of shape "
|
LOG(ERROR) << "Copying tensor of shape "
|
||||||
<< device_tensor.shape().DebugString() << " from "
|
<< device_tensor.shape().DebugString() << " from "
|
||||||
|
@ -488,15 +488,8 @@ Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context,
|
|||||||
mutex_lock lock(mu_);
|
mutex_lock lock(mu_);
|
||||||
Allocator* allocator = GetAllocatorLocked(alloc_attrs);
|
Allocator* allocator = GetAllocatorLocked(alloc_attrs);
|
||||||
Tensor copy(allocator, parsed.dtype(), parsed.shape());
|
Tensor copy(allocator, parsed.dtype(), parsed.shape());
|
||||||
Notification n;
|
TF_RETURN_IF_ERROR(
|
||||||
device_context->CopyCPUTensorToDevice(
|
device_context->CopyCPUTensorToDeviceSync(&parsed, this, &copy));
|
||||||
&parsed, this, &copy,
|
|
||||||
[&n, &status](const Status& s) {
|
|
||||||
status = s;
|
|
||||||
n.Notify();
|
|
||||||
},
|
|
||||||
true /*sync_dst_compute*/);
|
|
||||||
n.WaitForNotification();
|
|
||||||
*tensor = copy;
|
*tensor = copy;
|
||||||
}
|
}
|
||||||
VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
|
VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
|
||||||
|
@ -82,9 +82,8 @@ class UnaryOpsCompositionTest : public OpsTestBase {
|
|||||||
DeviceContext* device_context =
|
DeviceContext* device_context =
|
||||||
device_->tensorflow_gpu_device_info()->default_context;
|
device_->tensorflow_gpu_device_info()->default_context;
|
||||||
|
|
||||||
TF_CHECK_OK(BlockingCopy([&](StatusCallback cb) {
|
TF_CHECK_OK(device_context->CopyCPUTensorToDeviceSync(&input_on_host,
|
||||||
device_context->CopyCPUTensorToDevice(&input_on_host, device_, input, cb);
|
device_, input));
|
||||||
}));
|
|
||||||
|
|
||||||
TF_ASSERT_OK(RunOpKernel());
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
|
||||||
@ -94,27 +93,12 @@ class UnaryOpsCompositionTest : public OpsTestBase {
|
|||||||
Tensor* output = GetOutput(0);
|
Tensor* output = GetOutput(0);
|
||||||
Tensor output_on_host(cpu_allocator, output->dtype(), output->shape());
|
Tensor output_on_host(cpu_allocator, output->dtype(), output->shape());
|
||||||
|
|
||||||
TF_CHECK_OK(BlockingCopy([&](StatusCallback cb) {
|
TF_CHECK_OK(device_context->CopyDeviceTensorToCPUSync(
|
||||||
device_context->CopyDeviceTensorToCPU(output, "output 0", device_,
|
output, "output 0", device_, &output_on_host));
|
||||||
&output_on_host, cb);
|
|
||||||
}));
|
|
||||||
|
|
||||||
test::ExpectClose(expected_tensor, output_on_host, /*atol=*/1e-5,
|
test::ExpectClose(expected_tensor, output_on_host, /*atol=*/1e-5,
|
||||||
/*rtol=*/1e-5);
|
/*rtol=*/1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
template <typename CopyFnTy>
|
|
||||||
Status BlockingCopy(CopyFnTy copy_fn) {
|
|
||||||
Notification n;
|
|
||||||
Status status;
|
|
||||||
copy_fn([&](Status s) {
|
|
||||||
status = s;
|
|
||||||
n.Notify();
|
|
||||||
});
|
|
||||||
n.WaitForNotification();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sqrt_F) {
|
TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sqrt_F) {
|
||||||
|
@ -2295,6 +2295,7 @@ tf_cuda_library(
|
|||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/synchronization",
|
||||||
"@com_google_absl//absl/time",
|
"@com_google_absl//absl/time",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
"//tensorflow/core/example:feature_util",
|
"//tensorflow/core/example:feature_util",
|
||||||
|
@ -94,25 +94,14 @@ class GPUDeviceTest : public ::testing::Test {
|
|||||||
|
|
||||||
void CopyCPUToGPU(Tensor* cpu_tensor, Tensor* gpu_tensor, Device* device,
|
void CopyCPUToGPU(Tensor* cpu_tensor, Tensor* gpu_tensor, Device* device,
|
||||||
DeviceContext* device_context) {
|
DeviceContext* device_context) {
|
||||||
Notification note;
|
TF_ASSERT_OK(device_context->CopyCPUTensorToDeviceSync(cpu_tensor, device,
|
||||||
device_context->CopyCPUTensorToDevice(cpu_tensor, device, gpu_tensor,
|
gpu_tensor));
|
||||||
[&note](const Status& s) {
|
|
||||||
TF_ASSERT_OK(s);
|
|
||||||
note.Notify();
|
|
||||||
});
|
|
||||||
note.WaitForNotification();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void CopyGPUToCPU(Tensor* gpu_tensor, Tensor* cpu_tensor, Device* device,
|
void CopyGPUToCPU(Tensor* gpu_tensor, Tensor* cpu_tensor, Device* device,
|
||||||
DeviceContext* device_context) {
|
DeviceContext* device_context) {
|
||||||
Notification note;
|
TF_ASSERT_OK(device_context->CopyDeviceTensorToCPUSync(
|
||||||
device_context->CopyDeviceTensorToCPU(gpu_tensor, /*tensor_name=*/"",
|
gpu_tensor, /*tensor_name=*/"", device, cpu_tensor));
|
||||||
device, cpu_tensor,
|
|
||||||
[&note](const Status& s) {
|
|
||||||
TF_ASSERT_OK(s);
|
|
||||||
note.Notify();
|
|
||||||
});
|
|
||||||
note.WaitForNotification();
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -442,13 +442,8 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
|
|||||||
Device* dev = instances_[broadcast_dev_id]->device_;
|
Device* dev = instances_[broadcast_dev_id]->device_;
|
||||||
auto* dev_info = dev->tensorflow_gpu_device_info();
|
auto* dev_info = dev->tensorflow_gpu_device_info();
|
||||||
CHECK(dev_info);
|
CHECK(dev_info);
|
||||||
dev_info->default_context->CopyDeviceTensorToCPU(
|
TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
|
||||||
t, "" /*tensor_name*/, dev, &cpu_copy,
|
t, "" /*tensor_name*/, dev, &cpu_copy));
|
||||||
[this, &notification](Status s) {
|
|
||||||
TF_CHECK_OK(s);
|
|
||||||
notification.Notify();
|
|
||||||
});
|
|
||||||
notification.WaitForNotification();
|
|
||||||
t = &cpu_copy;
|
t = &cpu_copy;
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < t->NumElements(); ++i) {
|
for (size_t i = 0; i < t->NumElements(); ++i) {
|
||||||
@ -473,17 +468,11 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
|
|||||||
if (device_type_ == DEVICE_CPU) {
|
if (device_type_ == DEVICE_CPU) {
|
||||||
CHECK(actual.CopyFrom(*inst, inst->shape()));
|
CHECK(actual.CopyFrom(*inst, inst->shape()));
|
||||||
} else if (device_type_ == DEVICE_GPU) {
|
} else if (device_type_ == DEVICE_GPU) {
|
||||||
Notification notification;
|
|
||||||
Device* dev = instances_[di]->device_;
|
Device* dev = instances_[di]->device_;
|
||||||
auto* dev_info = dev->tensorflow_gpu_device_info();
|
auto* dev_info = dev->tensorflow_gpu_device_info();
|
||||||
CHECK(dev_info);
|
CHECK(dev_info);
|
||||||
dev_info->default_context->CopyDeviceTensorToCPU(
|
TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
|
||||||
inst, "" /*tensor_name*/, dev, &actual,
|
inst, "" /*tensor_name*/, dev, &actual));
|
||||||
[this, &notification](Status s) {
|
|
||||||
TF_CHECK_OK(s);
|
|
||||||
notification.Notify();
|
|
||||||
});
|
|
||||||
notification.WaitForNotification();
|
|
||||||
}
|
}
|
||||||
for (int i = 0; i < tensor_len; ++i) {
|
for (int i = 0; i < tensor_len; ++i) {
|
||||||
switch (dtype) {
|
switch (dtype) {
|
||||||
@ -623,12 +612,8 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
|
|||||||
Notification notification;
|
Notification notification;
|
||||||
auto* dev_info = device_->tensorflow_gpu_device_info();
|
auto* dev_info = device_->tensorflow_gpu_device_info();
|
||||||
CHECK(dev_info);
|
CHECK(dev_info);
|
||||||
dev_info->default_context->CopyCPUTensorToDevice(
|
TF_CHECK_OK(dev_info->default_context->CopyCPUTensorToDeviceSync(
|
||||||
&cpu_tensor, device_, &tensor_, [&notification](Status s) {
|
&cpu_tensor, device_, &tensor_));
|
||||||
TF_CHECK_OK(s);
|
|
||||||
notification.Notify();
|
|
||||||
});
|
|
||||||
notification.WaitForNotification();
|
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Unsupported device_type " << device_type_;
|
LOG(FATAL) << "Unsupported device_type " << device_type_;
|
||||||
}
|
}
|
||||||
|
@ -157,17 +157,11 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
DeviceContext* device_context =
|
DeviceContext* device_context =
|
||||||
gpu_device_->tensorflow_gpu_device_info()->default_context;
|
gpu_device_->tensorflow_gpu_device_info()->default_context;
|
||||||
|
|
||||||
Notification n;
|
|
||||||
Status status;
|
|
||||||
Tensor cpu_tensor(device_tensor.dtype(), device_tensor.shape());
|
Tensor cpu_tensor(device_tensor.dtype(), device_tensor.shape());
|
||||||
device_context->CopyDeviceTensorToCPU(&device_tensor, "", gpu_device_,
|
CHECK(device_context
|
||||||
&cpu_tensor,
|
->CopyDeviceTensorToCPUSync(&device_tensor, "", gpu_device_,
|
||||||
[&n, &status](const Status& s) {
|
&cpu_tensor)
|
||||||
status = s;
|
.ok());
|
||||||
n.Notify();
|
|
||||||
});
|
|
||||||
n.WaitForNotification();
|
|
||||||
CHECK(status.ok());
|
|
||||||
return cpu_tensor;
|
return cpu_tensor;
|
||||||
#else
|
#else
|
||||||
CHECK(false);
|
CHECK(false);
|
||||||
@ -181,18 +175,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
DeviceContext* device_context =
|
DeviceContext* device_context =
|
||||||
gpu_device_->tensorflow_gpu_device_info()->default_context;
|
gpu_device_->tensorflow_gpu_device_info()->default_context;
|
||||||
|
|
||||||
Notification n;
|
|
||||||
Status status;
|
|
||||||
Tensor device_tensor(gpu_device_->GetAllocator({}), cpu_tensor.dtype(),
|
Tensor device_tensor(gpu_device_->GetAllocator({}), cpu_tensor.dtype(),
|
||||||
cpu_tensor.shape(), {});
|
cpu_tensor.shape(), {});
|
||||||
device_context->CopyCPUTensorToDevice(&cpu_tensor, gpu_device_,
|
CHECK(device_context
|
||||||
&device_tensor,
|
->CopyCPUTensorToDeviceSync(&cpu_tensor, gpu_device_,
|
||||||
[&n, &status](const Status& s) {
|
&device_tensor)
|
||||||
status = s;
|
.ok());
|
||||||
n.Notify();
|
|
||||||
});
|
|
||||||
n.WaitForNotification();
|
|
||||||
CHECK(status.ok());
|
|
||||||
return device_tensor;
|
return device_tensor;
|
||||||
#else
|
#else
|
||||||
CHECK(false);
|
CHECK(false);
|
||||||
|
@ -254,14 +254,9 @@ string RingAlg::TensorDebugString(const Tensor& tensor) {
|
|||||||
col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
|
col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
|
||||||
if (gpu_device_info) {
|
if (gpu_device_info) {
|
||||||
Tensor cpu_tensor(tensor.dtype(), tensor.shape());
|
Tensor cpu_tensor(tensor.dtype(), tensor.shape());
|
||||||
Notification note;
|
Status st = gpu_device_info->default_context->CopyDeviceTensorToCPUSync(
|
||||||
gpu_device_info->default_context->CopyDeviceTensorToCPU(
|
&tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor);
|
||||||
&tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor,
|
DCHECK(st.ok());
|
||||||
[&note](const Status& s) {
|
|
||||||
DCHECK(s.ok());
|
|
||||||
note.Notify();
|
|
||||||
});
|
|
||||||
note.WaitForNotification();
|
|
||||||
return cpu_tensor.SummarizeValue(64);
|
return cpu_tensor.SummarizeValue(64);
|
||||||
} else {
|
} else {
|
||||||
return tensor.SummarizeValue(64);
|
return tensor.SummarizeValue(64);
|
||||||
|
@ -310,16 +310,13 @@ class RingGathererTest : public ::testing::Test {
|
|||||||
CHECK(actual.CopyFrom(*inst, inst->shape()));
|
CHECK(actual.CopyFrom(*inst, inst->shape()));
|
||||||
VLOG(1) << "actual " << actual.SummarizeValue(100);
|
VLOG(1) << "actual " << actual.SummarizeValue(100);
|
||||||
} else if (device_type_ == DEVICE_GPU) {
|
} else if (device_type_ == DEVICE_GPU) {
|
||||||
Notification note;
|
|
||||||
Device* dev = instances_[di]->device_;
|
Device* dev = instances_[di]->device_;
|
||||||
auto* dev_info = dev->tensorflow_gpu_device_info();
|
auto* dev_info = dev->tensorflow_gpu_device_info();
|
||||||
CHECK(dev_info);
|
CHECK(dev_info);
|
||||||
dev_info->default_context->CopyDeviceTensorToCPU(
|
CHECK(dev_info->default_context
|
||||||
inst, "" /*tensor_name*/, dev, &actual, [&note](const Status& s) {
|
->CopyDeviceTensorToCPUSync(inst, "" /*tensor_name*/, dev,
|
||||||
CHECK(s.ok());
|
&actual)
|
||||||
note.Notify();
|
.ok());
|
||||||
});
|
|
||||||
note.WaitForNotification();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto alias = actual.template unaligned_flat<T>();
|
auto alias = actual.template unaligned_flat<T>();
|
||||||
@ -433,13 +430,10 @@ class RingGathererTest : public ::testing::Test {
|
|||||||
init_f(&cpu_tensor);
|
init_f(&cpu_tensor);
|
||||||
auto* dev_info = device_->tensorflow_gpu_device_info();
|
auto* dev_info = device_->tensorflow_gpu_device_info();
|
||||||
CHECK(dev_info);
|
CHECK(dev_info);
|
||||||
Notification note;
|
CHECK(dev_info->default_context
|
||||||
dev_info->default_context->CopyCPUTensorToDevice(
|
->CopyCPUTensorToDeviceSync(&cpu_tensor, device_,
|
||||||
&cpu_tensor, device_, &input_tensor_, [&note](const Status& s) {
|
&input_tensor_)
|
||||||
CHECK(s.ok());
|
.ok());
|
||||||
note.Notify();
|
|
||||||
});
|
|
||||||
note.WaitForNotification();
|
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Unsupported device_type " << device_type_;
|
LOG(FATAL) << "Unsupported device_type " << device_type_;
|
||||||
}
|
}
|
||||||
|
@ -331,16 +331,13 @@ class RingReducerTest : public ::testing::Test {
|
|||||||
CHECK(actual.CopyFrom(*inst, inst->shape()));
|
CHECK(actual.CopyFrom(*inst, inst->shape()));
|
||||||
VLOG(1) << "actual " << actual.SummarizeValue(100);
|
VLOG(1) << "actual " << actual.SummarizeValue(100);
|
||||||
} else if (device_type_ == DEVICE_GPU) {
|
} else if (device_type_ == DEVICE_GPU) {
|
||||||
Notification note;
|
|
||||||
Device* dev = instances_[di]->device_;
|
Device* dev = instances_[di]->device_;
|
||||||
auto* dev_info = dev->tensorflow_gpu_device_info();
|
auto* dev_info = dev->tensorflow_gpu_device_info();
|
||||||
CHECK(dev_info);
|
CHECK(dev_info);
|
||||||
dev_info->default_context->CopyDeviceTensorToCPU(
|
CHECK(dev_info->default_context
|
||||||
inst, "" /*tensor_name*/, dev, &actual, [&note](const Status& s) {
|
->CopyDeviceTensorToCPUSync(inst, "" /*tensor_name*/, dev,
|
||||||
CHECK(s.ok());
|
&actual)
|
||||||
note.Notify();
|
.ok());
|
||||||
});
|
|
||||||
note.WaitForNotification();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto alias = actual.template unaligned_flat<T>();
|
auto alias = actual.template unaligned_flat<T>();
|
||||||
@ -458,13 +455,9 @@ class RingReducerTest : public ::testing::Test {
|
|||||||
init_f(&cpu_tensor);
|
init_f(&cpu_tensor);
|
||||||
auto* dev_info = device_->tensorflow_gpu_device_info();
|
auto* dev_info = device_->tensorflow_gpu_device_info();
|
||||||
CHECK(dev_info);
|
CHECK(dev_info);
|
||||||
Notification note;
|
CHECK(dev_info->default_context
|
||||||
dev_info->default_context->CopyCPUTensorToDevice(
|
->CopyCPUTensorToDeviceSync(&cpu_tensor, device_, &tensor_)
|
||||||
&cpu_tensor, device_, &tensor_, [&note](const Status& s) {
|
.ok());
|
||||||
CHECK(s.ok());
|
|
||||||
note.Notify();
|
|
||||||
});
|
|
||||||
note.WaitForNotification();
|
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Unsupported device_type " << device_type_;
|
LOG(FATAL) << "Unsupported device_type " << device_type_;
|
||||||
}
|
}
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "absl/synchronization/notification.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/util/work_sharder.h"
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
@ -33,6 +34,35 @@ DeviceBase::~DeviceBase() {
|
|||||||
eigen_cpu_devices_.clear();
|
eigen_cpu_devices_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status DeviceContext::CopyDeviceTensorToCPUSync(const Tensor* device_tensor,
|
||||||
|
StringPiece tensor_name,
|
||||||
|
Device* device,
|
||||||
|
Tensor* cpu_tensor) {
|
||||||
|
absl::Notification n;
|
||||||
|
Status status;
|
||||||
|
CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
|
||||||
|
[&](const Status& s) {
|
||||||
|
status = s;
|
||||||
|
n.Notify();
|
||||||
|
});
|
||||||
|
n.WaitForNotification();
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status DeviceContext::CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor,
|
||||||
|
Device* device,
|
||||||
|
Tensor* device_tensor) const {
|
||||||
|
absl::Notification n;
|
||||||
|
Status status;
|
||||||
|
CopyCPUTensorToDevice(cpu_tensor, device, device_tensor,
|
||||||
|
[&](const Status& s) {
|
||||||
|
status = s;
|
||||||
|
n.Notify();
|
||||||
|
});
|
||||||
|
n.WaitForNotification();
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
const DeviceAttributes& DeviceBase::attributes() const {
|
const DeviceAttributes& DeviceBase::attributes() const {
|
||||||
LOG(FATAL) << "Device does not implement attributes()";
|
LOG(FATAL) << "Device does not implement attributes()";
|
||||||
}
|
}
|
||||||
|
@ -84,6 +84,10 @@ class DeviceContext : public core::RefCounted {
|
|||||||
done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
|
done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Same as CopyCPUTensorToDevice, but in a synchronous way.
|
||||||
|
Status CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor, Device* device,
|
||||||
|
Tensor* device_tensor) const;
|
||||||
|
|
||||||
// Copies a tensor in this device.
|
// Copies a tensor in this device.
|
||||||
virtual void CopyTensorInSameDevice(const Tensor* input_tensor,
|
virtual void CopyTensorInSameDevice(const Tensor* input_tensor,
|
||||||
Device* device, Tensor* output_tensor,
|
Device* device, Tensor* output_tensor,
|
||||||
@ -100,6 +104,11 @@ class DeviceContext : public core::RefCounted {
|
|||||||
done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
|
done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Same as `CopyDeviceTensorToCPU`, but blocks until the copy is done.
|
||||||
|
Status CopyDeviceTensorToCPUSync(const Tensor* device_tensor,
|
||||||
|
StringPiece tensor_name, Device* device,
|
||||||
|
Tensor* cpu_tensor);
|
||||||
|
|
||||||
// If possible, wait for all events on *stream to complete then execute func.
|
// If possible, wait for all events on *stream to complete then execute func.
|
||||||
// A non-OK Status is returned otherwise. The stream argument should be the
|
// A non-OK Status is returned otherwise. The stream argument should be the
|
||||||
// one provided by GpuDeviceInfo. This function is not applicable to devices
|
// one provided by GpuDeviceInfo. This function is not applicable to devices
|
||||||
|
@ -205,15 +205,10 @@ class NcclTestBase : public ::testing::Test {
|
|||||||
VLOG(2) << "rank " << rank << " output " << output << " buf "
|
VLOG(2) << "rank " << rank << " output " << output << " buf "
|
||||||
<< DMAHelper::base(output);
|
<< DMAHelper::base(output);
|
||||||
Tensor actual(DT_FLOAT, TensorShape({output_length}));
|
Tensor actual(DT_FLOAT, TensorShape({output_length}));
|
||||||
Notification note;
|
|
||||||
Device* dev = instances_[rank]->device_;
|
Device* dev = instances_[rank]->device_;
|
||||||
auto* dev_info = dev->tensorflow_gpu_device_info();
|
auto* dev_info = dev->tensorflow_gpu_device_info();
|
||||||
dev_info->default_context->CopyDeviceTensorToCPU(
|
TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
|
||||||
output, /*tensor_name=*/"", dev, &actual, [&note](const Status& s) {
|
output, /*tensor_name=*/"", dev, &actual));
|
||||||
TF_CHECK_OK(s);
|
|
||||||
note.Notify();
|
|
||||||
});
|
|
||||||
note.WaitForNotification();
|
|
||||||
VLOG(3) << "rank " << rank << " got output tensor "
|
VLOG(3) << "rank " << rank << " got output tensor "
|
||||||
<< actual.DebugString(output_length);
|
<< actual.DebugString(output_length);
|
||||||
for (int i = 0; i < output_length; ++i) {
|
for (int i = 0; i < output_length; ++i) {
|
||||||
@ -270,13 +265,8 @@ class NcclTestBase : public ::testing::Test {
|
|||||||
VLOG(2) << "input tensor " << cpu_tensor.DebugString();
|
VLOG(2) << "input tensor " << cpu_tensor.DebugString();
|
||||||
}
|
}
|
||||||
auto* dev_info = device_->tensorflow_gpu_device_info();
|
auto* dev_info = device_->tensorflow_gpu_device_info();
|
||||||
Notification note;
|
TF_CHECK_OK(dev_info->default_context->CopyCPUTensorToDeviceSync(
|
||||||
dev_info->default_context->CopyCPUTensorToDevice(
|
&cpu_tensor, device_, &input_));
|
||||||
&cpu_tensor, device_, &input_, [&note](const Status& s) {
|
|
||||||
TF_CHECK_OK(s);
|
|
||||||
note.Notify();
|
|
||||||
});
|
|
||||||
note.WaitForNotification();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PrepareDeviceContext(OpKernelContext::Params* params) {
|
void PrepareDeviceContext(OpKernelContext::Params* params) {
|
||||||
|
@ -447,14 +447,8 @@ class WhileOp : public AsyncOpKernel {
|
|||||||
Device* device = down_cast<Device*>(ctx_->device());
|
Device* device = down_cast<Device*>(ctx_->device());
|
||||||
DeviceContext* device_ctx = ctx_->op_device_context();
|
DeviceContext* device_ctx = ctx_->op_device_context();
|
||||||
cond_t = Tensor(rets_[0].dtype(), rets_[0].shape());
|
cond_t = Tensor(rets_[0].dtype(), rets_[0].shape());
|
||||||
Notification done_copy;
|
s = device_ctx->CopyDeviceTensorToCPUSync(&rets_[0], /*tensor_name=*/"",
|
||||||
device_ctx->CopyDeviceTensorToCPU(
|
device, &cond_t);
|
||||||
&rets_[0], /*tensor_name=*/"", device, &cond_t,
|
|
||||||
[&done_copy, &s](const Status& status) {
|
|
||||||
s = status;
|
|
||||||
done_copy.Notify();
|
|
||||||
});
|
|
||||||
done_copy.WaitForNotification();
|
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
return Finish(s);
|
return Finish(s);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user