Give custom devices the option to do type-based dispatch for ops with no explicit placement
When there is a custom device input and one or more physical device inputs to an op, presents the op to the custom device but indicates that the user did not explicitly request the placement (via the device property of the passed op). Custom devices which want to stick to strict scope-based placement can either copy off the inputs and run the op on the default device or throw an error. The parallel device will stick with scope-only dispatch for now. PiperOrigin-RevId: 328840123 Change-Id: Ic7490c0700a7ca5c74fd362211fa2fc9e008051c
This commit is contained in:
parent
e6f238fbaf
commit
a764f3ab76
@ -1614,19 +1614,12 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
return status.status;
|
||||
}
|
||||
|
||||
tensorflow::Status Execute(tensorflow::EagerOperation* op,
|
||||
tensorflow::Status Execute(const tensorflow::EagerOperation* op,
|
||||
tensorflow::TensorHandle** retvals,
|
||||
int* num_retvals) override {
|
||||
std::vector<TFE_TensorHandle*> inputs;
|
||||
inputs.reserve(op->Inputs().size());
|
||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||
op->Inputs()[i]->Ref();
|
||||
inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
TF_Status status;
|
||||
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
|
||||
device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
|
||||
info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
@ -1636,10 +1629,6 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
TFE_DeleteTensorHandle(outputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto inp : inputs) {
|
||||
TFE_DeleteTensorHandle(inp);
|
||||
}
|
||||
return status.status;
|
||||
}
|
||||
|
||||
|
@ -435,7 +435,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
|
||||
size_t proto_len,
|
||||
TF_Status* status);
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 2
|
||||
// TODO(b/166642410): It would be nice, for custom devices and for other users,
|
||||
// to have a non-string representation of devices (TF_Device) extracted from
|
||||
// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 3
|
||||
|
||||
// Struct to be filled in
|
||||
typedef struct TFE_CustomDevice {
|
||||
@ -454,9 +458,16 @@ typedef struct TFE_CustomDevice {
|
||||
void* device_info);
|
||||
|
||||
// Method to execute an operation.
|
||||
void (*execute)(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
//
|
||||
// Arguments provide enough information to reconstruct the original `TFE_Op`,
|
||||
// or construct a transformed version, by inspecting the passed `op`.
|
||||
//
|
||||
// TFE_OpGetDevice(op) records the original placement of the operation. It may
|
||||
// be an empty string if no device was explicitly requested, but will
|
||||
// otherwise be the name of this custom device. Ops are placed onto a custom
|
||||
// device if any of their inputs are on that custom device, but custom devices
|
||||
// are free to set a bad status in order to require explicit placement.
|
||||
void (*execute)(const TFE_Op* op, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
|
||||
|
||||
// Method to delete a device.
|
||||
|
@ -36,7 +36,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
|
||||
RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true,
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
|
||||
ASSERT_FALSE(arrived);
|
||||
@ -73,7 +74,8 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
|
||||
bool executed = false;
|
||||
const char* custom_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
|
||||
RegisterLoggingDevice(context.get(), custom_device_name,
|
||||
/*strict_scope_placement=*/true, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
@ -103,7 +105,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle placed on the custom device.
|
||||
@ -187,7 +190,8 @@ TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false,
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle placed on the custom device.
|
||||
@ -264,10 +268,12 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
|
||||
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
|
||||
RegisterLoggingDevice(context.get(), custom0,
|
||||
/*strict_scope_placement=*/false, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
|
||||
RegisterLoggingDevice(context.get(), custom1,
|
||||
/*strict_scope_placement=*/true, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
@ -314,14 +320,34 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
|
||||
|
||||
// Custom device: mix of custom/physical fails.
|
||||
// Custom device: mix of custom/physical places the op on the custom device.
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
|
||||
num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
|
||||
ASSERT_TRUE(
|
||||
absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull
|
||||
EXPECT_TRUE(executed);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
|
||||
// Explicit placement still forces the op onto the requested device
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
|
||||
TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
EXPECT_FALSE(executed);
|
||||
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
|
||||
|
||||
// Custom devices can refuse to do type-based dispatch (as hcustom1 is
|
||||
// configured to do)
|
||||
matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get()));
|
||||
num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
EXPECT_FALSE(executed);
|
||||
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||
@ -334,21 +360,24 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
|
||||
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0",
|
||||
/*strict_scope_placement=*/true, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
|
||||
<< TF_Message(status.get());
|
||||
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||
<< TF_Message(status.get());
|
||||
|
||||
RegisterLoggingDevice(context.get(),
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||
<< TF_Message(status.get());
|
||||
|
||||
RegisterLoggingDevice(
|
||||
context.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
/*strict_scope_placement=*/true, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||
<< TF_Message(status.get());
|
||||
}
|
||||
|
@ -33,6 +33,9 @@ struct LoggingDevice {
|
||||
bool* arrived_flag;
|
||||
// Set to true whenever an operation is executed
|
||||
bool* executed_flag;
|
||||
// If true, only explicit op placements are accepted. If false, uses
|
||||
// type-based dispatch.
|
||||
bool strict_scope_placement;
|
||||
};
|
||||
|
||||
struct LoggedTensor {
|
||||
@ -84,18 +87,35 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info) {
|
||||
const char* requested_placement = TFE_OpGetDevice(original_op, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
if (dev->strict_scope_placement && *requested_placement == '\0') {
|
||||
TF_SetStatus(s, TF_INTERNAL,
|
||||
"Ops must be placed on the device explicitly, or their inputs "
|
||||
"first copied to other devices.");
|
||||
return;
|
||||
}
|
||||
TFE_Context* context = TFE_OpGetContext(original_op, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
const char* operation_name = TFE_OpGetName(original_op, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
|
||||
|
||||
TFE_Op* op(TFE_NewOp(context, operation_name, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddAttrs(op, attributes);
|
||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
for (int j = 0; j < num_inputs; ++j) {
|
||||
TFE_TensorHandle* input = inputs[j];
|
||||
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
const char* input_device = TFE_TensorHandleDeviceName(input, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
if (dev->device_name == input_device) {
|
||||
@ -131,8 +151,8 @@ void DeleteLoggingDevice(void* device_info) {
|
||||
} // namespace
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag, bool* executed_flag,
|
||||
TF_Status* status) {
|
||||
bool strict_scope_placement, bool* arrived_flag,
|
||||
bool* executed_flag, TF_Status* status) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||
@ -143,6 +163,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
device->executed_flag = executed_flag;
|
||||
device->device_name = name;
|
||||
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
device->strict_scope_placement = strict_scope_placement;
|
||||
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
|
||||
}
|
||||
|
||||
@ -168,5 +189,6 @@ void AllocateLoggingDevice(const char* name, bool* arrived_flag,
|
||||
logging_device->device_name = name;
|
||||
logging_device->underlying_device =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
logging_device->strict_scope_placement = true;
|
||||
*device_info = reinterpret_cast<void*>(logging_device);
|
||||
}
|
||||
|
@ -25,8 +25,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag, bool* executed_flag,
|
||||
TF_Status* status);
|
||||
bool strict_scope_placement, bool* arrived_flag,
|
||||
bool* executed_flag, TF_Status* status);
|
||||
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
|
||||
bool* executed_flag, TFE_CustomDevice** device,
|
||||
void** device_info);
|
||||
|
@ -255,28 +255,44 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
|
||||
// Since this function is used to satisfy the TFE_CustomDevice C API,
|
||||
// device_info is passed in using a C-style generic. It must always be a
|
||||
// ParallelDevice.
|
||||
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* status,
|
||||
void* device_info) {
|
||||
const char* requested_placement = TFE_OpGetDevice(original_op, status);
|
||||
if (*requested_placement == '\0') {
|
||||
TF_SetStatus(
|
||||
status, TF_INTERNAL,
|
||||
"Ops must be placed on the parallel device explicitly, or their inputs "
|
||||
"first un-packed. Got an un-placed op with an input placed on the "
|
||||
"parallel device.");
|
||||
return;
|
||||
}
|
||||
TFE_Context* context = TFE_OpGetContext(original_op, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* operation_name = TFE_OpGetName(original_op, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
|
||||
|
||||
NamedParallelDevice* named_device =
|
||||
reinterpret_cast<NamedParallelDevice*>(device_info);
|
||||
std::vector<MaybeParallelTensorUnowned> typed_inputs;
|
||||
int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
typed_inputs.reserve(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* tensor_handle_device =
|
||||
TFE_TensorHandleDeviceName(inputs[i], status);
|
||||
TFE_TensorHandleDeviceName(input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
if (named_device->name() == tensor_handle_device) {
|
||||
// We assume that any tensors already placed on this device are
|
||||
// ParallelTensors.
|
||||
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
|
||||
TFE_TensorHandleDevicePointer(inputs[i], status)));
|
||||
TFE_TensorHandleDevicePointer(input, status)));
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
} else {
|
||||
typed_inputs.emplace_back(inputs[i]);
|
||||
typed_inputs.emplace_back(input);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -114,7 +114,7 @@ class CustomDevice {
|
||||
const string& target_device_name,
|
||||
TensorHandle** result) = 0;
|
||||
|
||||
virtual Status Execute(EagerOperation* op, TensorHandle** retvals,
|
||||
virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
|
||||
int* num_retvals) = 0;
|
||||
};
|
||||
|
||||
|
@ -208,12 +208,18 @@ Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
device = ctx_.HostCPU();
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::TensorHandle** retval_array =
|
||||
reinterpret_cast<tensorflow::TensorHandle**>(retvals.data());
|
||||
if (VariantDeviceIsCustom(device)) {
|
||||
return absl::get<CustomDevice*>(device)->Execute(this, retval_array,
|
||||
num_retvals);
|
||||
}
|
||||
|
||||
if (device != kVariantDeviceNull) {
|
||||
SetDevice(device);
|
||||
}
|
||||
return EagerExecute(
|
||||
this, reinterpret_cast<tensorflow::TensorHandle**>(retvals.data()),
|
||||
num_retvals);
|
||||
return EagerExecute(this, retval_array, num_retvals);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -1070,11 +1070,6 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
[&] { return absl::StrCat("EagerExecute: ", op->Name()); },
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
|
||||
if (VariantDeviceIsCustom(op->Device())) {
|
||||
return absl::get<CustomDevice*>(op->Device())
|
||||
->Execute(op, retvals, num_retvals);
|
||||
}
|
||||
|
||||
if (!op->Executor().Async()) {
|
||||
// In sync mode, always clear error to maintain the same behavior as before.
|
||||
// TODO(b/141004939): Remove this.
|
||||
|
@ -185,34 +185,35 @@ Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
|
||||
if (VariantDeviceIsCustom(op.Device())) {
|
||||
*device = op.Device();
|
||||
return Status::OK();
|
||||
} else if (!op.DeviceName().empty()) {
|
||||
// Don't override explicit placements.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Ops are placed on a custom device if there's no other explicit requested
|
||||
// placement and there is only one custom device in the op inputs.
|
||||
if (!op.Inputs().empty()) {
|
||||
// We keep track of what we've seen with devices instead of booleans to be
|
||||
// able to provide a meaningful error message below.
|
||||
VariantDevice first = op.Inputs()[0]->device();
|
||||
VariantDevice different = first; // A different input device, if any.
|
||||
VariantDevice custom = first; // The first custom device seen, or an
|
||||
// arbitrary non-custom device otherwise.
|
||||
for (size_t i = 1; first == different && i < op.Inputs().size(); ++i) {
|
||||
VariantDevice device = op.Inputs()[i]->device();
|
||||
if (device != first) {
|
||||
different = device;
|
||||
}
|
||||
if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) {
|
||||
custom = device;
|
||||
}
|
||||
if (different != first && VariantDeviceIsCustom(custom)) {
|
||||
return errors::InvalidArgument(absl::StrCat(
|
||||
"If an operation has one of its inputs in a custom device, then "
|
||||
"all inputs should be on that same device. Operation ",
|
||||
op.Name(), " has one input in custom device ",
|
||||
VariantDeviceName(custom),
|
||||
" and at least one input in a different device ",
|
||||
VariantDeviceName(custom == first ? different : first)));
|
||||
CustomDevice* first = nullptr;
|
||||
for (const TensorHandle* input : op.Inputs()) {
|
||||
if (VariantDeviceIsCustom(input->device())) {
|
||||
CustomDevice* current = absl::get<CustomDevice*>(input->device());
|
||||
if (first == nullptr) {
|
||||
first = current;
|
||||
} else if (first != current) {
|
||||
return errors::InvalidArgument(absl::StrCat(
|
||||
"If an operation has one of its inputs in a custom device, then "
|
||||
"all inputs should be on that same custom device or another "
|
||||
"physical device. Operation ",
|
||||
op.Name(),
|
||||
" has one input in custom "
|
||||
"device ",
|
||||
VariantDeviceName(first),
|
||||
" and at least one input in a different custom device ",
|
||||
VariantDeviceName(current)));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (different == first && VariantDeviceIsCustom(custom)) {
|
||||
if (first != nullptr) {
|
||||
*device = first;
|
||||
return Status::OK();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user