[tf.data] Switch tf.data functions to default to using multi-device backend.

PiperOrigin-RevId: 309476584
Change-Id: I250ab2972a94cd9ce7527e2fb3de215d93b15c85
This commit is contained in:
Bruce Fontaine 2020-05-01 14:40:33 -07:00 committed by TensorFlower Gardener
parent 1a1346145d
commit 91e3f65507
3 changed files with 81 additions and 36 deletions

View File

@ -551,7 +551,8 @@ Status CapturedFunction::Instantiate(
if (!metadata_->use_inter_op_parallelism()) {
inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
}
bool is_multi_device = metadata_->use_multi_device_function();
bool is_multi_device = false;
TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));
inst_opts.is_multi_device_function = is_multi_device;
// We infer the target device from the function library runtime.
@ -670,20 +671,13 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
ret_types_);
Notification n;
Status s;
profiler::TraceMe activity(
[&] {
return absl::StrCat(
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
s.Update(func_status);
n.Notify();
});
n.WaitForNotification();
TF_RETURN_IF_ERROR(s);
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
return frame.ConsumeRetvals(rets);
}
@ -708,9 +702,6 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
ret_types_);
Notification n;
Status s;
profiler::TraceMe activity(
[&] {
return absl::StrCat(
@ -718,12 +709,7 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
s.Update(func_status);
n.Notify();
});
n.WaitForNotification();
TF_RETURN_IF_ERROR(s);
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
return frame.ConsumeRetvals(rets);
}
@ -747,21 +733,13 @@ Status InstantiatedCapturedFunction::RunInstantiated(
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
ret_types_);
Notification n;
Status s;
profiler::TraceMe activity(
[&] {
return absl::StrCat("InstantiatedCapturedFunction::RunInstantiated#id=",
f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
s.Update(func_status);
n.Notify();
});
n.WaitForNotification();
TF_RETURN_IF_ERROR(s);
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
return frame.ConsumeRetvals(rets);
}
@ -877,5 +855,77 @@ CapturedFunction::CapturedFunction(
: metadata_(std::move(metadata)),
captured_inputs_(std::move(captured_inputs)) {}
Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
bool* is_multi_device) {
if (!metadata_->use_multi_device_function()) {
*is_multi_device = false;
return Status::OK();
}
const FunctionDef* fdef;
TF_RETURN_IF_ERROR(
LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
Device* current_device = ctx->flr()->device();
DeviceType current_device_type(current_device->device_type());
DeviceNameUtils::ParsedName current_device_name;
if (!DeviceNameUtils::ParseFullName(current_device->name(),
&current_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
current_device->name());
}
// Check if any of the captured inputs are placed on a device not compatible
// with the current device. For non-captured inputs, we assume they are placed
// on the current device.
for (const auto& input : captured_inputs_) {
DataType dtype = input.dtype();
if (dtype == DT_RESOURCE) {
const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
DeviceNameUtils::ParsedName resource_device_name;
if (!DeviceNameUtils::ParseFullName(handle.device(),
&resource_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
handle.device());
}
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
resource_device_name)) {
*is_multi_device = true;
return Status::OK();
}
}
}
// Check if all ops could be placed on the current device.
for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
const FunctionDef* fdef;
TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
for (const auto& node : fdef->node_def()) {
// Check if the op has a kernel available for the current device.
if (!KernelDefAvailable(current_device_type, node)) {
*is_multi_device = true;
return Status::OK();
}
// If the op has a requested device, check if the requested device is
// compatible with the current device.
if (!node.device().empty()) {
DeviceNameUtils::ParsedName node_device_name;
if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
node.device());
}
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
node_device_name)) {
*is_multi_device = true;
return Status::OK();
}
}
}
}
*is_multi_device = false;
return Status::OK();
}
} // namespace data
} // namespace tensorflow

View File

@ -256,6 +256,10 @@ class CapturedFunction {
CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor> captured_inputs);
// Determines whether the captured function requires the use of the
// multi-device function backend.
Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device);
const std::shared_ptr<const FunctionMetadata> metadata_;
const std::vector<Tensor> captured_inputs_;

View File

@ -115,15 +115,6 @@ class MapBenchmark(benchmark_base.DatasetBenchmarkBase):
name="parallel_control_flow",
apply_default_optimizations=True)
def benchmark_execution_overhead(self):
dataset = dataset_ops.Dataset.range(100000)
dataset = dataset_ops.MapDataset(
dataset, lambda x: x + 1, use_inter_op_parallelism=False)
self.run_and_report_benchmark(
dataset,
num_elements=100000,
name="execution_overhead")
if __name__ == "__main__":
benchmark_base.test.main()