Ignore tensors on custom devices when deciding in what device to execute ops.
This implicitly assumes custom devices always have lower priority, which is fine, I think. We might need to be smart about custom resource devices as TF code might assume resource-using ops are always colocated with the resources, though. PiperOrigin-RevId: 300638783 Change-Id: I6c4915bdacd3b873f46d2154077ebd7af2d6f0b5
This commit is contained in:
parent
465c3fae60
commit
ead7a372a8
@ -296,6 +296,76 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle placed on the custom device.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
|
||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
|
||||
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
|
||||
|
||||
// Assign to the variable, copying to the custom device.
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpAddInput(op.get(), one.get(), status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
|
||||
// Read the variable's value.
|
||||
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
|
||||
<< "Execution should fail because the variable is being used on the "
|
||||
"wrong device.";
|
||||
// Free the backing buffer for the variable.
|
||||
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
@ -863,6 +863,9 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
|
||||
: absl::get<Device*>(op->Device());
|
||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||
TensorHandle* tensor_handle = op->Inputs()[i];
|
||||
if (VariantDeviceIsCustom(tensor_handle->DeviceOrHostCPU(ctx))) {
|
||||
continue; // Do not try to let custom devices influence op placement.
|
||||
}
|
||||
if (tensor_handle->dtype == DT_RESOURCE) {
|
||||
Device* resource_device = tensor_handle->resource_device();
|
||||
DVLOG(2) << "for op " << op->Name() << " input " << i << " "
|
||||
|
Loading…
Reference in New Issue
Block a user