[tf.data] Switching to using multi-device function by default.
PiperOrigin-RevId: 312830323 Change-Id: I9e1ae4aea3ab230f06a26dc79a17fc3aa66ca422
This commit is contained in:
parent
f8c0e68a8a
commit
0c83272451
|
@ -560,8 +560,7 @@ Status CapturedFunction::Instantiate(
|
|||
if (!metadata_->use_inter_op_parallelism()) {
|
||||
inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
|
||||
}
|
||||
bool is_multi_device = false;
|
||||
TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));
|
||||
bool is_multi_device = metadata_->use_multi_device_function();
|
||||
inst_opts.is_multi_device_function = is_multi_device;
|
||||
|
||||
// We infer the target device from the function library runtime.
|
||||
|
@ -864,77 +863,5 @@ CapturedFunction::CapturedFunction(
|
|||
: metadata_(std::move(metadata)),
|
||||
captured_inputs_(std::move(captured_inputs)) {}
|
||||
|
||||
Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
|
||||
bool* is_multi_device) {
|
||||
if (!metadata_->use_multi_device_function()) {
|
||||
*is_multi_device = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const FunctionDef* fdef;
|
||||
TF_RETURN_IF_ERROR(
|
||||
LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
|
||||
|
||||
Device* current_device = ctx->flr()->device();
|
||||
DeviceType current_device_type(current_device->device_type());
|
||||
DeviceNameUtils::ParsedName current_device_name;
|
||||
if (!DeviceNameUtils::ParseFullName(current_device->name(),
|
||||
¤t_device_name)) {
|
||||
return errors::InvalidArgument("Failed to parse device name: ",
|
||||
current_device->name());
|
||||
}
|
||||
|
||||
// Check if any of the captured inputs are placed on a device not compatible
|
||||
// with the current device. For non-captured inputs, we assume they are placed
|
||||
// on the current device.
|
||||
for (const auto& input : captured_inputs_) {
|
||||
DataType dtype = input.dtype();
|
||||
if (dtype == DT_RESOURCE) {
|
||||
const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
|
||||
DeviceNameUtils::ParsedName resource_device_name;
|
||||
if (!DeviceNameUtils::ParseFullName(handle.device(),
|
||||
&resource_device_name)) {
|
||||
return errors::InvalidArgument("Failed to parse device name: ",
|
||||
handle.device());
|
||||
}
|
||||
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
|
||||
resource_device_name)) {
|
||||
*is_multi_device = true;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all ops could be placed on the current device.
|
||||
for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
|
||||
const FunctionDef* fdef;
|
||||
TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
|
||||
for (const auto& node : fdef->node_def()) {
|
||||
// Check if the op has a kernel available for the current device.
|
||||
if (!KernelDefAvailable(current_device_type, node)) {
|
||||
*is_multi_device = true;
|
||||
return Status::OK();
|
||||
}
|
||||
// If the op has a requested device, check if the requested device is
|
||||
// compatible with the current device.
|
||||
if (!node.device().empty()) {
|
||||
DeviceNameUtils::ParsedName node_device_name;
|
||||
if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
|
||||
return errors::InvalidArgument("Failed to parse device name: ",
|
||||
node.device());
|
||||
}
|
||||
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
|
||||
node_device_name)) {
|
||||
*is_multi_device = true;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
*is_multi_device = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -256,10 +256,6 @@ class CapturedFunction {
|
|||
CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
|
||||
std::vector<Tensor> captured_inputs);
|
||||
|
||||
// Determines whether the captured function requires the use of the
|
||||
// multi-device function backend.
|
||||
Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device);
|
||||
|
||||
const std::shared_ptr<const FunctionMetadata> metadata_;
|
||||
const std::vector<Tensor> captured_inputs_;
|
||||
|
||||
|
|
Loading…
Reference in New Issue