diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc
index 931a6626c9a..b6e6369bb43 100644
--- a/tensorflow/c/eager/custom_device_test.cc
+++ b/tensorflow/c/eager/custom_device_test.cc
@@ -296,6 +296,76 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 }
 
+TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+      TFE_NewContextOptions(), TFE_DeleteContextOptions);
+  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  bool arrived = false;
+  bool executed = false;
+  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
+  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+  // Create a variable handle placed on the custom device.
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
+  TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
+  TFE_OpSetAttrString(op.get(), "container", "", 0);
+  TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
+  TFE_OpSetDevice(op.get(), name, status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  TFE_TensorHandle* var_handle = nullptr;
+  int num_retvals = 1;
+  executed = false;
+  TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  ASSERT_TRUE(executed);
+  auto handle_cleaner = tensorflow::gtl::MakeCleanup(
+      [var_handle]() { TFE_DeleteTensorHandle(var_handle); });
+
+  // Assign to the variable, copying to the custom device.
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
+      TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
+  op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
+  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
+  TFE_OpAddInput(op.get(), var_handle, status.get());
+  TFE_OpAddInput(op.get(), one.get(), status.get());
+  TFE_OpSetDevice(op.get(), name, status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  executed = false;
+  num_retvals = 0;
+  TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  ASSERT_TRUE(executed);
+
+  // Read the variable's value.
+  op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
+  TFE_OpAddInput(op.get(), var_handle, status.get());
+  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  executed = false;
+  num_retvals = 1;
+  TFE_TensorHandle* var_value = nullptr;
+  TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
+  EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
+      << "Execution should fail because the variable is being used on the "
+         "wrong device.";
+  // Free the backing buffer for the variable.
+ op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get())); + TFE_OpAddInput(op.get(), var_handle, status.get()); + TFE_OpSetDevice(op.get(), name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + num_retvals = 0; + TFE_Execute(op.get(), nullptr, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); +} + TEST(CUSTOM_DEVICE, InvalidRegistrationError) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 8ac4eb35fdf..baaddec74e1 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -863,6 +863,9 @@ Status MaybeUpdateOpDevice(EagerOperation* op) { : absl::get(op->Device()); for (int i = 0; i < op->Inputs().size(); ++i) { TensorHandle* tensor_handle = op->Inputs()[i]; + if (VariantDeviceIsCustom(tensor_handle->DeviceOrHostCPU(ctx))) { + continue; // Do not try to let custom devices influence op placement. + } if (tensor_handle->dtype == DT_RESOURCE) { Device* resource_device = tensor_handle->resource_device(); DVLOG(2) << "for op " << op->Name() << " input " << i << " "