Automatically place operation on custom device when reasonably safe.
This changes the placement logic to use a custom device for an operation if all of its inputs are on that device. It also makes the placement fail if there are inputs on different devices, and at least one of them is custom. PiperOrigin-RevId: 306977958 Change-Id: I91cf665d374fa5d0a2f9693d2813e590e06d0645
This commit is contained in:
parent
9b24e4fa8b
commit
0ade89f3ed
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
// A simple logging device to test custom device registration.
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
@ -25,7 +26,6 @@ limitations under the License.
|
|||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
|
||||
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
@ -176,7 +176,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
|
|||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
|
@ -226,16 +226,21 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
|||
|
||||
// Read the variable's value.
|
||||
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
executed = false;
|
||||
num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
|
||||
<< "Execution should fail because the variable is being used on the "
|
||||
"wrong device.";
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
ASSERT_EQ(
|
||||
tensorflow::string(name),
|
||||
tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get())));
|
||||
TFE_DeleteTensorHandle(var_value);
|
||||
|
||||
// Free the backing buffer for the variable.
|
||||
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
|
@ -246,6 +251,79 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
|||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, InputBasedPlacement) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
|
||||
TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
|
||||
ASSERT_FALSE(arrived);
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom0(
|
||||
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
|
||||
status.get()),
|
||||
TFE_DeleteTensorHandle);
|
||||
ASSERT_TRUE(arrived);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
arrived = false;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom1(
|
||||
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
|
||||
status.get()),
|
||||
TFE_DeleteTensorHandle);
|
||||
ASSERT_TRUE(arrived);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Base case: two CPU inputs executes fine.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
|
||||
MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
|
||||
TFE_TensorHandle* retval;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
|
||||
// Custom device: inputs in same custom device works.
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
|
||||
num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
|
||||
// Custom device: inputs in different custom devices fails.
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
|
||||
num_retvals = 1;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
|
||||
|
||||
// Custom device: mix of custom/physical fails.
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
|
||||
num_retvals = 1;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
|
||||
ASSERT_TRUE(
|
||||
absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
|
|
@ -63,19 +63,10 @@ class EagerOperation : public AbstractOperationInterface {
|
|||
// updated to that device.
|
||||
Status SetDeviceName(const char* name) override;
|
||||
|
||||
void SetDevice(tensorflow::Device* device) {
|
||||
void SetDevice(VariantDevice device) {
|
||||
device_ = device;
|
||||
device_name_ = device->name();
|
||||
device_parsed_name_ = device->parsed_name();
|
||||
// TODO(b/154133594): Due to intricacies of external logic, we can not
|
||||
// set this do device_name_ as it would be natural, because we need the
|
||||
// next call to SetDeviceName to reset the device pointer.
|
||||
last_set_device_name_ = "\177"; // DEL (an invalid value)
|
||||
}
|
||||
|
||||
void SetDevice(tensorflow::CustomDevice* device) {
|
||||
device_ = device;
|
||||
device_name_ = device->name();
|
||||
device_name_ =
|
||||
device == kVariantDeviceNull ? "" : VariantDeviceName(device);
|
||||
DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
|
||||
// TODO(b/154133594): Due to intricacies of external logic, we can not
|
||||
// set this do device_name_ as it would be natural, because we need the
|
||||
|
|
|
@ -852,7 +852,47 @@ bool IsPinnableOp(const string& op_type) {
|
|||
// - All op inputs are on the CPU, small (<64 elements) and integers
|
||||
// (int32/int64). This can be disabled by setting the environment variable
|
||||
// "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
|
||||
//
|
||||
// TODO(b/154234908): Unify placement logic.
|
||||
Status MaybeUpdateOpDevice(EagerOperation* op) {
|
||||
// If operation was already placed on a custom device, use it.
|
||||
if (VariantDeviceIsCustom(op->Device())) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If all the inputs are on the same custom device, use that custom
|
||||
// device. Otherwise, it is an error to have a custom device as an input.
|
||||
if (!op->Inputs().empty()) {
|
||||
// We keep track of what we've seen with devices instead of booleans to be
|
||||
// able to provide a meaningful error message below.
|
||||
VariantDevice first = op->Inputs()[0]->device();
|
||||
VariantDevice different = first; // A different input device, if any.
|
||||
VariantDevice custom = first; // The first custom device seen, or an
|
||||
// arbitrary non-custom device otherwise.
|
||||
for (size_t i = 1; first == different && i < op->Inputs().size(); ++i) {
|
||||
VariantDevice device = op->Inputs()[i]->device();
|
||||
if (device != first) {
|
||||
different = device;
|
||||
}
|
||||
if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) {
|
||||
custom = device;
|
||||
}
|
||||
if (different != first && VariantDeviceIsCustom(custom)) {
|
||||
return errors::InvalidArgument(absl::StrCat(
|
||||
"If an operation has one of its inputs in a custom device, then "
|
||||
"all inputs should be on that same device. Operation ",
|
||||
op->Name(), " has one input in custom device ",
|
||||
VariantDeviceName(custom),
|
||||
" and at least one input in a different device ",
|
||||
VariantDeviceName(custom == first ? different : first)));
|
||||
}
|
||||
}
|
||||
if (different == first && VariantDeviceIsCustom(custom)) {
|
||||
op->SetDevice(first);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
if (op->colocation_exempt()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -864,9 +904,6 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
|
|||
: absl::get<Device*>(op->Device());
|
||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||
TensorHandle* tensor_handle = op->Inputs()[i];
|
||||
if (VariantDeviceIsCustom(tensor_handle->DeviceOrHostCPU(ctx))) {
|
||||
continue; // Do not try to let custom devices influence op placement.
|
||||
}
|
||||
if (tensor_handle->dtype == DT_RESOURCE) {
|
||||
Device* resource_device = tensor_handle->resource_device();
|
||||
DVLOG(2) << "for op " << op->Name() << " input " << i << " "
|
||||
|
@ -956,13 +993,13 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
|
|||
[&] { return absl::StrCat("EagerExecute: ", op->Name()); },
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
|
||||
|
||||
if (VariantDeviceIsCustom(op->Device())) {
|
||||
return absl::get<CustomDevice*>(op->Device())
|
||||
->Execute(op, retvals, num_retvals);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
|
||||
|
||||
if (!op->Executor().Async()) {
|
||||
// In sync mode, always clear error to maintain the same behavior as before.
|
||||
// TODO(b/141004939): Remove this.
|
||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
#ifdef INTEL_MKL
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
|
||||
#include "tensorflow/core/graph/mkl_graph_util.h"
|
||||
#include "tensorflow/core/graph/mkl_layout_pass.h"
|
||||
|
@ -135,15 +136,7 @@ Status MklEagerOpRewrite::SetupNewOp(
|
|||
->MutableAttrs()
|
||||
->Set("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
|
||||
|
||||
if (orig_op->Device() == kVariantDeviceNull) {
|
||||
string device_name = orig_op->DeviceName();
|
||||
(*new_mkl_op)->SetDeviceName(device_name.c_str());
|
||||
} else if (VariantDeviceIsCustom(orig_op->Device())) {
|
||||
(*new_mkl_op)->SetDevice(absl::get<CustomDevice*>(orig_op->Device()));
|
||||
} else {
|
||||
(*new_mkl_op)->SetDevice(absl::get<Device*>(orig_op->Device()));
|
||||
}
|
||||
return Status::OK();
|
||||
return (*new_mkl_op)->SetDeviceName(device_name.c_str());
|
||||
}
|
||||
|
||||
Status MklEagerOpRewrite::CreateGenericMklOp(
|
||||
|
|
|
@ -792,6 +792,9 @@ bool VariantDeviceIsCustom(VariantDevice variant_device) {
|
|||
}
|
||||
|
||||
string VariantDeviceName(VariantDevice device) {
|
||||
if (device == kVariantDeviceNull) {
|
||||
return "[]";
|
||||
}
|
||||
return absl::visit([](auto* device) { return device->name(); }, device);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue