[tf.data] Switch tf.data functions to default to using multi-device backend.
PiperOrigin-RevId: 309339306 Change-Id: I1ee62e80823a2d17ec7a8706af6aa5b6bce45188
This commit is contained in:
parent
196b03bb55
commit
624f1a0f82
tensorflow
@ -551,8 +551,7 @@ Status CapturedFunction::Instantiate(
|
|||||||
if (!metadata_->use_inter_op_parallelism()) {
|
if (!metadata_->use_inter_op_parallelism()) {
|
||||||
inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
|
inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
|
||||||
}
|
}
|
||||||
bool is_multi_device = false;
|
bool is_multi_device = metadata_->use_multi_device_function();
|
||||||
TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));
|
|
||||||
inst_opts.is_multi_device_function = is_multi_device;
|
inst_opts.is_multi_device_function = is_multi_device;
|
||||||
|
|
||||||
// We infer the target device from the function library runtime.
|
// We infer the target device from the function library runtime.
|
||||||
@ -671,13 +670,20 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
|
|||||||
|
|
||||||
OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
|
OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
|
||||||
ret_types_);
|
ret_types_);
|
||||||
|
Notification n;
|
||||||
|
Status s;
|
||||||
profiler::TraceMe activity(
|
profiler::TraceMe activity(
|
||||||
[&] {
|
[&] {
|
||||||
return absl::StrCat(
|
return absl::StrCat(
|
||||||
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
|
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
|
||||||
},
|
},
|
||||||
profiler::TraceMeLevel::kInfo);
|
profiler::TraceMeLevel::kInfo);
|
||||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
|
||||||
|
s.Update(func_status);
|
||||||
|
n.Notify();
|
||||||
|
});
|
||||||
|
n.WaitForNotification();
|
||||||
|
TF_RETURN_IF_ERROR(s);
|
||||||
return frame.ConsumeRetvals(rets);
|
return frame.ConsumeRetvals(rets);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -702,6 +708,9 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
|||||||
|
|
||||||
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
|
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
|
||||||
ret_types_);
|
ret_types_);
|
||||||
|
Notification n;
|
||||||
|
Status s;
|
||||||
|
|
||||||
profiler::TraceMe activity(
|
profiler::TraceMe activity(
|
||||||
[&] {
|
[&] {
|
||||||
return absl::StrCat(
|
return absl::StrCat(
|
||||||
@ -709,7 +718,12 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
|||||||
f_opts.step_id, "#");
|
f_opts.step_id, "#");
|
||||||
},
|
},
|
||||||
profiler::TraceMeLevel::kInfo);
|
profiler::TraceMeLevel::kInfo);
|
||||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
|
||||||
|
s.Update(func_status);
|
||||||
|
n.Notify();
|
||||||
|
});
|
||||||
|
n.WaitForNotification();
|
||||||
|
TF_RETURN_IF_ERROR(s);
|
||||||
return frame.ConsumeRetvals(rets);
|
return frame.ConsumeRetvals(rets);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -733,13 +747,21 @@ Status InstantiatedCapturedFunction::RunInstantiated(
|
|||||||
|
|
||||||
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
|
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
|
||||||
ret_types_);
|
ret_types_);
|
||||||
|
Notification n;
|
||||||
|
Status s;
|
||||||
|
|
||||||
profiler::TraceMe activity(
|
profiler::TraceMe activity(
|
||||||
[&] {
|
[&] {
|
||||||
return absl::StrCat("InstantiatedCapturedFunction::RunInstantiated#id=",
|
return absl::StrCat("InstantiatedCapturedFunction::RunInstantiated#id=",
|
||||||
f_opts.step_id, "#");
|
f_opts.step_id, "#");
|
||||||
},
|
},
|
||||||
profiler::TraceMeLevel::kInfo);
|
profiler::TraceMeLevel::kInfo);
|
||||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
|
||||||
|
s.Update(func_status);
|
||||||
|
n.Notify();
|
||||||
|
});
|
||||||
|
n.WaitForNotification();
|
||||||
|
TF_RETURN_IF_ERROR(s);
|
||||||
return frame.ConsumeRetvals(rets);
|
return frame.ConsumeRetvals(rets);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -855,77 +877,5 @@ CapturedFunction::CapturedFunction(
|
|||||||
: metadata_(std::move(metadata)),
|
: metadata_(std::move(metadata)),
|
||||||
captured_inputs_(std::move(captured_inputs)) {}
|
captured_inputs_(std::move(captured_inputs)) {}
|
||||||
|
|
||||||
Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
|
|
||||||
bool* is_multi_device) {
|
|
||||||
if (!metadata_->use_multi_device_function()) {
|
|
||||||
*is_multi_device = false;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
const FunctionDef* fdef;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
|
|
||||||
|
|
||||||
Device* current_device = ctx->flr()->device();
|
|
||||||
DeviceType current_device_type(current_device->device_type());
|
|
||||||
DeviceNameUtils::ParsedName current_device_name;
|
|
||||||
if (!DeviceNameUtils::ParseFullName(current_device->name(),
|
|
||||||
&current_device_name)) {
|
|
||||||
return errors::InvalidArgument("Failed to parse device name: ",
|
|
||||||
current_device->name());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if any of the captured inputs are placed on a device not compatible
|
|
||||||
// with the current device. For non-captured inputs, we assume they are placed
|
|
||||||
// on the current device.
|
|
||||||
for (const auto& input : captured_inputs_) {
|
|
||||||
DataType dtype = input.dtype();
|
|
||||||
if (dtype == DT_RESOURCE) {
|
|
||||||
const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
|
|
||||||
DeviceNameUtils::ParsedName resource_device_name;
|
|
||||||
if (!DeviceNameUtils::ParseFullName(handle.device(),
|
|
||||||
&resource_device_name)) {
|
|
||||||
return errors::InvalidArgument("Failed to parse device name: ",
|
|
||||||
handle.device());
|
|
||||||
}
|
|
||||||
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
|
|
||||||
resource_device_name)) {
|
|
||||||
*is_multi_device = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if all ops could be placed on the current device.
|
|
||||||
for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
|
|
||||||
const FunctionDef* fdef;
|
|
||||||
TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
|
|
||||||
for (const auto& node : fdef->node_def()) {
|
|
||||||
// Check if the op has a kernel available for the current device.
|
|
||||||
if (!KernelDefAvailable(current_device_type, node)) {
|
|
||||||
*is_multi_device = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
// If the op has a requested device, check if the requested device is
|
|
||||||
// compatible with the current device.
|
|
||||||
if (!node.device().empty()) {
|
|
||||||
DeviceNameUtils::ParsedName node_device_name;
|
|
||||||
if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
|
|
||||||
return errors::InvalidArgument("Failed to parse device name: ",
|
|
||||||
node.device());
|
|
||||||
}
|
|
||||||
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
|
|
||||||
node_device_name)) {
|
|
||||||
*is_multi_device = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
*is_multi_device = false;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -256,10 +256,6 @@ class CapturedFunction {
|
|||||||
CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
|
CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
|
||||||
std::vector<Tensor> captured_inputs);
|
std::vector<Tensor> captured_inputs);
|
||||||
|
|
||||||
// Determines whether the captured function requires the use of the
|
|
||||||
// multi-device function backend.
|
|
||||||
Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device);
|
|
||||||
|
|
||||||
const std::shared_ptr<const FunctionMetadata> metadata_;
|
const std::shared_ptr<const FunctionMetadata> metadata_;
|
||||||
const std::vector<Tensor> captured_inputs_;
|
const std::vector<Tensor> captured_inputs_;
|
||||||
|
|
||||||
|
@ -115,6 +115,15 @@ class MapBenchmark(benchmark_base.DatasetBenchmarkBase):
|
|||||||
name="parallel_control_flow",
|
name="parallel_control_flow",
|
||||||
apply_default_optimizations=True)
|
apply_default_optimizations=True)
|
||||||
|
|
||||||
|
def benchmark_execution_overhead(self):
|
||||||
|
dataset = dataset_ops.Dataset.range(100000)
|
||||||
|
dataset = dataset_ops.MapDataset(
|
||||||
|
dataset, lambda x: x + 1, use_inter_op_parallelism=False)
|
||||||
|
self.run_and_report_benchmark(
|
||||||
|
dataset,
|
||||||
|
num_elements=100000,
|
||||||
|
name="execution_overhead")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
benchmark_base.test.main()
|
benchmark_base.test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user