Automatically place operation on custom device when reasonably safe.

This changes the placement logic to use a custom device for an
operation if all of its inputs are on that device.

It also makes the placement fail if there are inputs on different devices, and at least one of them is custom.

PiperOrigin-RevId: 306977958
Change-Id: I91cf665d374fa5d0a2f9693d2813e590e06d0645
This commit is contained in:
Cesar Crusius 2020-04-16 21:05:04 -07:00 committed by TensorFlower Gardener
parent 9b24e4fa8b
commit 0ade89f3ed
5 changed files with 135 additions and 33 deletions

View File

@ -16,6 +16,7 @@ limitations under the License.
// A simple logging device to test custom device registration.
#include <memory>
#include "absl/strings/match.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -176,7 +176,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
@ -226,16 +226,21 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpAddInput(op.get(), var_handle, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
<< "Execution should fail because the variable is being used on the "
"wrong device.";
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
ASSERT_EQ(
tensorflow::string(name),
tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get())));
TFE_DeleteTensorHandle(var_value);
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
@ -246,6 +251,79 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
// Verifies input-based placement for custom devices: an op whose inputs all
// live on one custom device is placed on that device; an op mixing a
// custom-device input with inputs on any other device must fail with an
// error message that names the devices involved.
TEST(CUSTOM_DEVICE, InputBasedPlacement) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // Two distinct custom devices registered through the same logging-device
  // test harness; both share the `arrived`/`executed` observation flags.
  const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
  bool arrived = false;
  bool executed = false;
  RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // One CPU-resident handle plus a copy of it on each custom device.
  // `arrived` is presumably flipped by the logging device when a tensor is
  // copied onto it — the ASSERT_FALSE/ASSERT_TRUE pair around each copy
  // checks exactly that.
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
      TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
  ASSERT_FALSE(arrived);
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom0(
      TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
                                   status.get()),
      TFE_DeleteTensorHandle);
  ASSERT_TRUE(arrived);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  arrived = false;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom1(
      TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
                                   status.get()),
      TFE_DeleteTensorHandle);
  ASSERT_TRUE(arrived);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // Base case: two CPU inputs executes fine.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
      MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
  TFE_TensorHandle* retval;
  int num_retvals = 1;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_DeleteTensorHandle(retval);
  // Custom device: inputs in same custom device works. `executed` going
  // true shows the custom device (not the default executor) ran the op.
  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
  num_retvals = 1;
  executed = false;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  TFE_DeleteTensorHandle(retval);
  // Custom device: inputs in different custom devices fails. Both device
  // names must appear in the error message so the user can locate the
  // conflicting inputs.
  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
  num_retvals = 1;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
  ASSERT_NE(TF_OK, TF_GetCode(status.get()));
  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
  // Custom device: mix of custom/physical fails. The CPU handle has no
  // variant device set, which VariantDeviceName renders as "[]".
  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
  num_retvals = 1;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
  ASSERT_NE(TF_OK, TF_GetCode(status.get()));
  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
  ASSERT_TRUE(
      absl::StrContains(TF_Message(status.get()), "[]"));  // kVariantDeviceNull
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);

View File

@ -63,19 +63,10 @@ class EagerOperation : public AbstractOperationInterface {
// updated to that device.
Status SetDeviceName(const char* name) override;
void SetDevice(tensorflow::Device* device) {
void SetDevice(VariantDevice device) {
device_ = device;
device_name_ = device->name();
device_parsed_name_ = device->parsed_name();
// TODO(b/154133594): Due to intricacies of external logic, we can not
// set this to device_name_ as it would be natural, because we need the
// next call to SetDeviceName to reset the device pointer.
last_set_device_name_ = "\177"; // DEL (an invalid value)
}
void SetDevice(tensorflow::CustomDevice* device) {
device_ = device;
device_name_ = device->name();
device_name_ =
device == kVariantDeviceNull ? "" : VariantDeviceName(device);
DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
// TODO(b/154133594): Due to intricacies of external logic, we can not
// set this to device_name_ as it would be natural, because we need the

View File

@ -852,7 +852,47 @@ bool IsPinnableOp(const string& op_type) {
// - All op inputs are on the CPU, small (<64 elements) and integers
// (int32/int64). This can be disabled by setting the environment variable
// "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
//
// TODO(b/154234908): Unify placement logic.
Status MaybeUpdateOpDevice(EagerOperation* op) {
// If operation was already placed on a custom device, use it.
if (VariantDeviceIsCustom(op->Device())) {
return Status::OK();
}
// If all the inputs are on the same custom device, use that custom
// device. Otherwise, it is an error to have a custom device as an input.
if (!op->Inputs().empty()) {
// We keep track of what we've seen with devices instead of booleans to be
// able to provide a meaningful error message below.
VariantDevice first = op->Inputs()[0]->device();
VariantDevice different = first; // A different input device, if any.
VariantDevice custom = first; // The first custom device seen, or an
// arbitrary non-custom device otherwise.
for (size_t i = 1; first == different && i < op->Inputs().size(); ++i) {
VariantDevice device = op->Inputs()[i]->device();
if (device != first) {
different = device;
}
if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) {
custom = device;
}
if (different != first && VariantDeviceIsCustom(custom)) {
return errors::InvalidArgument(absl::StrCat(
"If an operation has one of its inputs in a custom device, then "
"all inputs should be on that same device. Operation ",
op->Name(), " has one input in custom device ",
VariantDeviceName(custom),
" and at least one input in a different device ",
VariantDeviceName(custom == first ? different : first)));
}
}
if (different == first && VariantDeviceIsCustom(custom)) {
op->SetDevice(first);
return Status::OK();
}
}
if (op->colocation_exempt()) {
return Status::OK();
}
@ -864,9 +904,6 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
: absl::get<Device*>(op->Device());
for (int i = 0; i < op->Inputs().size(); ++i) {
TensorHandle* tensor_handle = op->Inputs()[i];
if (VariantDeviceIsCustom(tensor_handle->DeviceOrHostCPU(ctx))) {
continue; // Do not try to let custom devices influence op placement.
}
if (tensor_handle->dtype == DT_RESOURCE) {
Device* resource_device = tensor_handle->resource_device();
DVLOG(2) << "for op " << op->Name() << " input " << i << " "
@ -956,13 +993,13 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
[&] { return absl::StrCat("EagerExecute: ", op->Name()); },
profiler::TraceMeLevel::kInfo);
TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
if (VariantDeviceIsCustom(op->Device())) {
return absl::get<CustomDevice*>(op->Device())
->Execute(op, retvals, num_retvals);
}
TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
if (!op->Executor().Async()) {
// In sync mode, always clear error to maintain the same behavior as before.
// TODO(b/141004939): Remove this.

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifdef INTEL_MKL
#include <string>
#include <unordered_map>
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_layout_pass.h"
@ -135,15 +136,7 @@ Status MklEagerOpRewrite::SetupNewOp(
->MutableAttrs()
->Set("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
if (orig_op->Device() == kVariantDeviceNull) {
string device_name = orig_op->DeviceName();
(*new_mkl_op)->SetDeviceName(device_name.c_str());
} else if (VariantDeviceIsCustom(orig_op->Device())) {
(*new_mkl_op)->SetDevice(absl::get<CustomDevice*>(orig_op->Device()));
} else {
(*new_mkl_op)->SetDevice(absl::get<Device*>(orig_op->Device()));
}
return Status::OK();
return (*new_mkl_op)->SetDeviceName(device_name.c_str());
}
Status MklEagerOpRewrite::CreateGenericMklOp(

View File

@ -792,6 +792,9 @@ bool VariantDeviceIsCustom(VariantDevice variant_device) {
}
// Returns a human-readable name for `device`. The null device
// (kVariantDeviceNull) is rendered as "[]"; otherwise the name of the
// underlying physical or custom device is returned.
string VariantDeviceName(VariantDevice device) {
  if (device == kVariantDeviceNull) return "[]";
  return absl::visit([](auto* d) { return d->name(); }, device);
}