diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index dd64475d7d6..adba99d37a4 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -560,7 +560,8 @@ Status CapturedFunction::Instantiate( if (!metadata_->use_inter_op_parallelism()) { inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR"; } - bool is_multi_device = metadata_->use_multi_device_function(); + bool is_multi_device = false; + TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device)); inst_opts.is_multi_device_function = is_multi_device; // We infer the target device from the function library runtime. @@ -863,5 +864,77 @@ CapturedFunction::CapturedFunction( : metadata_(std::move(metadata)), captured_inputs_(std::move(captured_inputs)) {} +Status CapturedFunction::IsMultiDevice(IteratorContext* ctx, + bool* is_multi_device) { + if (!metadata_->use_multi_device_function()) { + *is_multi_device = false; + return Status::OK(); + } + + const FunctionDef* fdef; + TF_RETURN_IF_ERROR( + LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef)); + + Device* current_device = ctx->flr()->device(); + DeviceType current_device_type(current_device->device_type()); + DeviceNameUtils::ParsedName current_device_name; + if (!DeviceNameUtils::ParseFullName(current_device->name(), + ¤t_device_name)) { + return errors::InvalidArgument("Failed to parse device name: ", + current_device->name()); + } + + // Check if any of the captured inputs are placed on a device not compatible + // with the current device. For non-captured inputs, we assume they are placed + // on the current device. + for (const auto& input : captured_inputs_) { + DataType dtype = input.dtype(); + if (dtype == DT_RESOURCE) { + const ResourceHandle& handle = input.flat()(0); + DeviceNameUtils::ParsedName resource_device_name; + if (!DeviceNameUtils::ParseFullName(handle.device(), + &resource_device_name)) { + return errors::InvalidArgument("Failed to parse device name: ", + handle.device()); + } + if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, + resource_device_name)) { + *is_multi_device = true; + return Status::OK(); + } + } + } + + // Check if all ops could be placed on the current device. + for (const auto& name : metadata_->lib_def()->ListFunctionNames()) { + const FunctionDef* fdef; + TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef)); + for (const auto& node : fdef->node_def()) { + // Check if the op has a kernel available for the current device. + if (!KernelDefAvailable(current_device_type, node)) { + *is_multi_device = true; + return Status::OK(); + } + // If the op has a requested device, check if the requested device is + // compatible with the current device. + if (!node.device().empty()) { + DeviceNameUtils::ParsedName node_device_name; + if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) { + return errors::InvalidArgument("Failed to parse device name: ", + node.device()); + } + if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, + node_device_name)) { + *is_multi_device = true; + return Status::OK(); + } + } + } + } + + *is_multi_device = false; + return Status::OK(); +} + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index de424fc547c..284a02091dd 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -256,6 +256,10 @@ class CapturedFunction { CapturedFunction(std::shared_ptr metadata, std::vector captured_inputs); + // Determines whether the captured function requires the use of the + // multi-device function backend. + Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device); + const std::shared_ptr metadata_; const std::vector captured_inputs_;