Enable eager tensor mirroring all the time
This greatly improves performance by reducing the repeated host-to-device copies needed for op execution. PiperOrigin-RevId: 297039608 Change-Id: Ic5f0d5c4ff15f1af096ba0ea20e96af76e53a6d8
This commit is contained in:
parent
e10fba7de3
commit
7fbfccbd30
@ -369,7 +369,7 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
|
||||
void TensorHandleSilentCopy(bool async,
|
||||
TFE_ContextDevicePlacementPolicy global_policy,
|
||||
TFE_ContextDevicePlacementPolicy thread_policy,
|
||||
bool mirror, bool cpu_op) {
|
||||
bool cpu_op) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -392,12 +392,6 @@ void TensorHandleSilentCopy(bool async,
|
||||
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
|
||||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
if (mirror) {
|
||||
TFE_TensorHandleEnableImplicitMirroring(hcpu, status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandleEnableImplicitMirroring(hgpu, status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
|
||||
if (cpu_op) {
|
||||
@ -424,7 +418,7 @@ void TensorHandleSilentCopy(bool async,
|
||||
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
matmul->operation.get());
|
||||
if (mirror) {
|
||||
if (!async) {
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_EQ(op->GetInput(0), arg0);
|
||||
ASSERT_EQ(op->GetInput(1), arg1);
|
||||
@ -454,27 +448,19 @@ void TensorHandleSilentCopy(bool async,
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopy) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopyAsync) {
|
||||
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
|
||||
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleMirrorCopy) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, true, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleMirrorCopyCpu) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, true, true);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
}
|
||||
|
||||
void SetAndGetOpDevices(bool async) {
|
||||
|
@ -143,7 +143,7 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(false),
|
||||
is_async_(false),
|
||||
implicit_mirroring_(false),
|
||||
implicit_mirroring_(true),
|
||||
is_ready_(true),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Local TensorHandle: " << this
|
||||
@ -164,7 +164,7 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(false),
|
||||
is_async_(false),
|
||||
implicit_mirroring_(false),
|
||||
implicit_mirroring_(true),
|
||||
is_ready_(true),
|
||||
handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
@ -185,7 +185,7 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(false),
|
||||
is_async_(false),
|
||||
implicit_mirroring_(false),
|
||||
implicit_mirroring_(true),
|
||||
is_ready_(true),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
// TODO(allenl): Figure out a better op_device story for custom devices,
|
||||
@ -220,7 +220,7 @@ TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(false),
|
||||
is_async_(async),
|
||||
implicit_mirroring_(false),
|
||||
implicit_mirroring_(true),
|
||||
is_ready_(!async),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating empty Local TensorHandle: " << this
|
||||
@ -260,7 +260,7 @@ TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(true),
|
||||
is_async_(false),
|
||||
implicit_mirroring_(false),
|
||||
implicit_mirroring_(true),
|
||||
is_ready_(true),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Remote TensorHandle: " << this
|
||||
@ -297,7 +297,7 @@ TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(true),
|
||||
is_async_(true),
|
||||
implicit_mirroring_(false),
|
||||
implicit_mirroring_(true),
|
||||
is_ready_(false),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
|
||||
|
Loading…
Reference in New Issue
Block a user