From 624f1a0f82e79fc92c27b39c5db95a0a45b7a9d7 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Thu, 30 Apr 2020 18:28:02 -0700 Subject: [PATCH] [tf.data] Switch tf.data functions to default to using multi-device backend. PiperOrigin-RevId: 309339306 Change-Id: I1ee62e80823a2d17ec7a8706af6aa5b6bce45188 --- .../core/kernels/data/captured_function.cc | 104 +++++------------- .../core/kernels/data/captured_function.h | 4 - .../python/data/benchmarks/map_benchmark.py | 9 ++ 3 files changed, 36 insertions(+), 81 deletions(-) diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index b4244c95395..646896ddbb7 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -551,8 +551,7 @@ Status CapturedFunction::Instantiate( if (!metadata_->use_inter_op_parallelism()) { inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR"; } - bool is_multi_device = false; - TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device)); + bool is_multi_device = metadata_->use_multi_device_function(); inst_opts.is_multi_device_function = is_multi_device; // We infer the target device from the function library runtime. @@ -671,13 +670,20 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx, OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(), ret_types_); + Notification n; + Status s; profiler::TraceMe activity( [&] { return absl::StrCat( "InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#"); }, profiler::TraceMeLevel::kInfo); - TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); + lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) { + s.Update(func_status); + n.Notify(); + }); + n.WaitForNotification(); + TF_RETURN_IF_ERROR(s); return frame.ConsumeRetvals(rets); } @@ -702,6 +708,9 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs( BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(), ret_types_); + Notification n; + Status s; + profiler::TraceMe activity( [&] { return absl::StrCat( @@ -709,7 +718,12 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs( f_opts.step_id, "#"); }, profiler::TraceMeLevel::kInfo); - TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); + lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) { + s.Update(func_status); + n.Notify(); + }); + n.WaitForNotification(); + TF_RETURN_IF_ERROR(s); return frame.ConsumeRetvals(rets); } @@ -733,13 +747,21 @@ Status InstantiatedCapturedFunction::RunInstantiated( BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(), ret_types_); + Notification n; + Status s; + profiler::TraceMe activity( [&] { return absl::StrCat("InstantiatedCapturedFunction::RunInstantiated#id=", f_opts.step_id, "#"); }, profiler::TraceMeLevel::kInfo); - TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); + lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) { + s.Update(func_status); + n.Notify(); + }); + n.WaitForNotification(); + TF_RETURN_IF_ERROR(s); return frame.ConsumeRetvals(rets); } @@ -855,77 +877,5 @@ CapturedFunction::CapturedFunction( : metadata_(std::move(metadata)), captured_inputs_(std::move(captured_inputs)) {} -Status CapturedFunction::IsMultiDevice(IteratorContext* ctx, - bool* is_multi_device) { - if (!metadata_->use_multi_device_function()) { - *is_multi_device = false; - return Status::OK(); - } - - const FunctionDef* fdef; - TF_RETURN_IF_ERROR( - LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef)); - - Device* current_device = ctx->flr()->device(); - DeviceType current_device_type(current_device->device_type()); - DeviceNameUtils::ParsedName current_device_name; - if (!DeviceNameUtils::ParseFullName(current_device->name(), - ¤t_device_name)) { - return errors::InvalidArgument("Failed to parse device name: ", - current_device->name()); - } - - // Check if any of the captured inputs are placed on a device not compatible - // with the current device. For non-captured inputs, we assume they are placed - // on the current device. - for (const auto& input : captured_inputs_) { - DataType dtype = input.dtype(); - if (dtype == DT_RESOURCE) { - const ResourceHandle& handle = input.flat()(0); - DeviceNameUtils::ParsedName resource_device_name; - if (!DeviceNameUtils::ParseFullName(handle.device(), - &resource_device_name)) { - return errors::InvalidArgument("Failed to parse device name: ", - handle.device()); - } - if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, - resource_device_name)) { - *is_multi_device = true; - return Status::OK(); - } - } - } - - // Check if all ops could be placed on the current device. - for (const auto& name : metadata_->lib_def()->ListFunctionNames()) { - const FunctionDef* fdef; - TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef)); - for (const auto& node : fdef->node_def()) { - // Check if the op has a kernel available for the current device. - if (!KernelDefAvailable(current_device_type, node)) { - *is_multi_device = true; - return Status::OK(); - } - // If the op has a requested device, check if the requested device is - // compatible with the current device. - if (!node.device().empty()) { - DeviceNameUtils::ParsedName node_device_name; - if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) { - return errors::InvalidArgument("Failed to parse device name: ", - node.device()); - } - if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, - node_device_name)) { - *is_multi_device = true; - return Status::OK(); - } - } - } - } - - *is_multi_device = false; - return Status::OK(); -} - } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index 284a02091dd..de424fc547c 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -256,10 +256,6 @@ class CapturedFunction { CapturedFunction(std::shared_ptr metadata, std::vector captured_inputs); - // Determines whether the captured function requires the use of the - // multi-device function backend. - Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device); - const std::shared_ptr metadata_; const std::vector captured_inputs_; diff --git a/tensorflow/python/data/benchmarks/map_benchmark.py b/tensorflow/python/data/benchmarks/map_benchmark.py index aea0fe9847e..39c0c544e78 100644 --- a/tensorflow/python/data/benchmarks/map_benchmark.py +++ b/tensorflow/python/data/benchmarks/map_benchmark.py @@ -115,6 +115,15 @@ class MapBenchmark(benchmark_base.DatasetBenchmarkBase): name="parallel_control_flow", apply_default_optimizations=True) + def benchmark_execution_overhead(self): + dataset = dataset_ops.Dataset.range(100000) + dataset = dataset_ops.MapDataset( + dataset, lambda x: x + 1, use_inter_op_parallelism=False) + self.run_and_report_benchmark( + dataset, + num_elements=100000, + name="execution_overhead") + if __name__ == "__main__": benchmark_base.test.main()