Give custom devices the option to do type-based dispatch for ops with no explicit placement

When there is a custom device input and one or more physical device inputs to an op, presents the op to the custom device but indicates that the user did not explicitly request the placement (via the device property of the passed op).

Custom devices which want to stick to strict scope-based placement can either copy off the inputs and run the op on the default device or throw an error. The parallel device will stick with scope-only dispatch for now.

PiperOrigin-RevId: 328840123
Change-Id: Ic7490c0700a7ca5c74fd362211fa2fc9e008051c
This commit is contained in:
Allen Lavoie 2020-08-27 16:30:38 -07:00 committed by TensorFlower Gardener
parent e6f238fbaf
commit a764f3ab76
10 changed files with 152 additions and 83 deletions

View File

@ -1614,19 +1614,12 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
return status.status;
}
tensorflow::Status Execute(tensorflow::EagerOperation* op,
tensorflow::Status Execute(const tensorflow::EagerOperation* op,
tensorflow::TensorHandle** retvals,
int* num_retvals) override {
std::vector<TFE_TensorHandle*> inputs;
inputs.reserve(op->Inputs().size());
for (int i = 0; i < op->Inputs().size(); ++i) {
op->Inputs()[i]->Ref();
inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
@ -1636,10 +1629,6 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
TFE_DeleteTensorHandle(outputs[i]);
}
}
for (auto inp : inputs) {
TFE_DeleteTensorHandle(inp);
}
return status.status;
}

View File

@ -435,7 +435,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
size_t proto_len,
TF_Status* status);
#define TFE_CUSTOM_DEVICE_VERSION 2
// TODO(b/166642410): It would be nice, for custom devices and for other users,
// to have a non-string representation of devices (TF_Device) extracted from
// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.
#define TFE_CUSTOM_DEVICE_VERSION 3
// Struct to be filled in
typedef struct TFE_CustomDevice {
@ -454,9 +458,16 @@ typedef struct TFE_CustomDevice {
void* device_info);
// Method to execute an operation.
void (*execute)(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
//
// Arguments provide enough information to reconstruct the original `TFE_Op`,
// or construct a transformed version, by inspecting the passed `op`.
//
// TFE_OpGetDevice(op) records the original placement of the operation. It may
// be an empty string if no device was explicitly requested, but will
// otherwise be the name of this custom device. Ops are placed onto a custom
// device if any of their inputs are on that custom device, but custom devices
// are free to set a bad status in order to require explicit placement.
void (*execute)(const TFE_Op* op, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
// Method to delete a device.

View File

@ -36,7 +36,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
ASSERT_FALSE(arrived);
@ -73,7 +74,8 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
bool executed = false;
const char* custom_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
RegisterLoggingDevice(context.get(), custom_device_name,
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@ -103,7 +105,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
@ -187,7 +190,8 @@ TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
@ -264,10 +268,12 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
RegisterLoggingDevice(context.get(), custom0,
/*strict_scope_placement=*/false, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
RegisterLoggingDevice(context.get(), custom1,
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@ -314,14 +320,34 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
// Custom device: mix of custom/physical fails.
// Custom device: mix of custom/physical places the op on the custom device.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(
absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull
EXPECT_TRUE(executed);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteTensorHandle(retval);
// Explicit placement still forces the op onto the requested device
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
EXPECT_FALSE(executed);
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
// Custom devices can refuse to do type-based dispatch (as hcustom1 is
// configured to do)
matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
EXPECT_FALSE(executed);
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
@ -334,21 +360,24 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0",
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
RegisterLoggingDevice(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
RegisterLoggingDevice(
context.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
/*strict_scope_placement=*/true, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}

View File

@ -33,6 +33,9 @@ struct LoggingDevice {
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
// If true, only explicit op placements are accepted. If false, uses
// type-based dispatch.
bool strict_scope_placement;
};
struct LoggedTensor {
@ -84,18 +87,35 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
const char* requested_placement = TFE_OpGetDevice(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
if (dev->strict_scope_placement && *requested_placement == '\0') {
TF_SetStatus(s, TF_INTERNAL,
"Ops must be placed on the device explicitly, or their inputs "
"first copied to other devices.");
return;
}
TFE_Context* context = TFE_OpGetContext(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
const char* operation_name = TFE_OpGetName(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
if (TF_GetCode(s) != TF_OK) return;
int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
if (TF_GetCode(s) != TF_OK) return;
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
@ -131,8 +151,8 @@ void DeleteLoggingDevice(void* device_info) {
} // namespace
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
bool strict_scope_placement, bool* arrived_flag,
bool* executed_flag, TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
@ -143,6 +163,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name,
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
device->strict_scope_placement = strict_scope_placement;
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
@ -168,5 +189,6 @@ void AllocateLoggingDevice(const char* name, bool* arrived_flag,
logging_device->device_name = name;
logging_device->underlying_device =
"/job:localhost/replica:0/task:0/device:CPU:0";
logging_device->strict_scope_placement = true;
*device_info = reinterpret_cast<void*>(logging_device);
}

View File

@ -25,8 +25,8 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status);
bool strict_scope_placement, bool* arrived_flag,
bool* executed_flag, TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info);

View File

@ -255,28 +255,44 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
const char* requested_placement = TFE_OpGetDevice(original_op, status);
if (*requested_placement == '\0') {
TF_SetStatus(
status, TF_INTERNAL,
"Ops must be placed on the parallel device explicitly, or their inputs "
"first un-packed. Got an un-placed op with an input placed on the "
"parallel device.");
return;
}
TFE_Context* context = TFE_OpGetContext(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
const char* operation_name = TFE_OpGetName(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
NamedParallelDevice* named_device =
reinterpret_cast<NamedParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
if (TF_GetCode(status) != TF_OK) return;
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
if (named_device->name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
TFE_TensorHandleDevicePointer(inputs[i], status)));
TFE_TensorHandleDevicePointer(input, status)));
if (TF_GetCode(status) != TF_OK) return;
} else {
typed_inputs.emplace_back(inputs[i]);
typed_inputs.emplace_back(input);
}
}

View File

@ -114,7 +114,7 @@ class CustomDevice {
const string& target_device_name,
TensorHandle** result) = 0;
virtual Status Execute(EagerOperation* op, TensorHandle** retvals,
virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
int* num_retvals) = 0;
};

View File

@ -208,12 +208,18 @@ Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
device = ctx_.HostCPU();
}
}
tensorflow::TensorHandle** retval_array =
reinterpret_cast<tensorflow::TensorHandle**>(retvals.data());
if (VariantDeviceIsCustom(device)) {
return absl::get<CustomDevice*>(device)->Execute(this, retval_array,
num_retvals);
}
if (device != kVariantDeviceNull) {
SetDevice(device);
}
return EagerExecute(
this, reinterpret_cast<tensorflow::TensorHandle**>(retvals.data()),
num_retvals);
return EagerExecute(this, retval_array, num_retvals);
}
} // namespace tensorflow

View File

@ -1070,11 +1070,6 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
[&] { return absl::StrCat("EagerExecute: ", op->Name()); },
profiler::TraceMeLevel::kInfo);
if (VariantDeviceIsCustom(op->Device())) {
return absl::get<CustomDevice*>(op->Device())
->Execute(op, retvals, num_retvals);
}
if (!op->Executor().Async()) {
// In sync mode, always clear error to maintain the same behavior as before.
// TODO(b/141004939): Remove this.

View File

@ -185,34 +185,35 @@ Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
if (VariantDeviceIsCustom(op.Device())) {
*device = op.Device();
return Status::OK();
} else if (!op.DeviceName().empty()) {
// Don't override explicit placements.
return Status::OK();
}
// Ops are placed on a custom device if there's no other explicit requested
// placement and there is only one custom device in the op inputs.
if (!op.Inputs().empty()) {
// We keep track of what we've seen with devices instead of booleans to be
// able to provide a meaningful error message below.
VariantDevice first = op.Inputs()[0]->device();
VariantDevice different = first; // A different input device, if any.
VariantDevice custom = first; // The first custom device seen, or an
// arbitrary non-custom device otherwise.
for (size_t i = 1; first == different && i < op.Inputs().size(); ++i) {
VariantDevice device = op.Inputs()[i]->device();
if (device != first) {
different = device;
}
if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) {
custom = device;
}
if (different != first && VariantDeviceIsCustom(custom)) {
return errors::InvalidArgument(absl::StrCat(
"If an operation has one of its inputs in a custom device, then "
"all inputs should be on that same device. Operation ",
op.Name(), " has one input in custom device ",
VariantDeviceName(custom),
" and at least one input in a different device ",
VariantDeviceName(custom == first ? different : first)));
CustomDevice* first = nullptr;
for (const TensorHandle* input : op.Inputs()) {
if (VariantDeviceIsCustom(input->device())) {
CustomDevice* current = absl::get<CustomDevice*>(input->device());
if (first == nullptr) {
first = current;
} else if (first != current) {
return errors::InvalidArgument(absl::StrCat(
"If an operation has one of its inputs in a custom device, then "
"all inputs should be on that same custom device or another "
"physical device. Operation ",
op.Name(),
" has one input in custom "
"device ",
VariantDeviceName(first),
" and at least one input in a different custom device ",
VariantDeviceName(current)));
}
}
}
if (different == first && VariantDeviceIsCustom(custom)) {
if (first != nullptr) {
*device = first;
return Status::OK();
}