diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index d79cb25ec8b..f740d7ff1ad 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -565,8 +565,7 @@ Status CapturedFunction::Instantiate( if (!metadata_->use_inter_op_parallelism()) { inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR"; } - bool is_multi_device = metadata_->use_multi_device_function(); - inst_opts.is_multi_device_function = is_multi_device; + inst_opts.is_multi_device_function = metadata_->use_multi_device_function(); // We infer the target device from the function library runtime. DCHECK(lib->device() != nullptr); @@ -649,16 +648,13 @@ Status CapturedFunction::Instantiate( DataTypeVector ret_types; TF_RETURN_IF_ERROR(lib->GetRetTypes(f_handle, &ret_types)); - *instantiated_captured_function = - absl::WrapUnique( - new InstantiatedCapturedFunction(lib, f_handle, std::move(ret_types), - *ctx->runner(), this, - is_multi_device)); - return Status::OK(); + bool is_multi_device; + TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device)); + return InstantiatedCapturedFunction::Create( + lib, f_handle, std::move(ret_types), *ctx->runner(), this, + is_multi_device, instantiated_captured_function); } -bool CapturedFunction::IsStateful() const { return !CheckExternalState().ok(); } - Status CapturedFunction::CheckExternalState() const { for (const auto& name : lib_def()->ListFunctionNames()) { TF_RETURN_IF_ERROR( @@ -667,6 +663,95 @@ Status CapturedFunction::CheckExternalState() const { return Status::OK(); } +CapturedFunction::CapturedFunction( + std::shared_ptr metadata, + std::vector captured_inputs) + : metadata_(std::move(metadata)), + captured_inputs_(std::move(captured_inputs)) {} + +Status CapturedFunction::IsMultiDevice(IteratorContext* ctx, + bool* is_multi_device) const { + if (!metadata_->use_multi_device_function()) { + *is_multi_device = false; + return Status::OK(); + } + + const FunctionDef* fdef; + TF_RETURN_IF_ERROR( + LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef)); + + Device* current_device = ctx->flr()->device(); + DeviceType current_device_type(current_device->device_type()); + DeviceNameUtils::ParsedName current_device_name; + if (!DeviceNameUtils::ParseFullName(current_device->name(), + ¤t_device_name)) { + return errors::InvalidArgument("Failed to parse device name: ", + current_device->name()); + } + + // Check if any of the captured inputs are placed on a device not compatible + // with the current device. For non-captured inputs, we assume they are placed + // on the current device. + for (const auto& input : captured_inputs_) { + DataType dtype = input.dtype(); + if (dtype == DT_RESOURCE) { + const ResourceHandle& handle = input.flat()(0); + DeviceNameUtils::ParsedName resource_device_name; + if (!DeviceNameUtils::ParseFullName(handle.device(), + &resource_device_name)) { + return errors::InvalidArgument("Failed to parse device name: ", + handle.device()); + } + if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, + resource_device_name)) { + *is_multi_device = true; + return Status::OK(); + } + } + } + + // Check if all ops could be placed on the current device. + for (const auto& name : metadata_->lib_def()->ListFunctionNames()) { + const FunctionDef* fdef; + TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef)); + for (const auto& node : fdef->node_def()) { + // Check if the op has a kernel available for the current device. + if (!KernelDefAvailable(current_device_type, node)) { + *is_multi_device = true; + return Status::OK(); + } + // If the op has a requested device, check if the requested device is + // compatible with the current device. + if (!node.device().empty()) { + DeviceNameUtils::ParsedName node_device_name; + if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) { + return errors::InvalidArgument("Failed to parse device name: ", + node.device()); + } + if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, + node_device_name)) { + *is_multi_device = true; + return Status::OK(); + } + } + } + } + + *is_multi_device = false; + return Status::OK(); +} + +/* static */ +Status InstantiatedCapturedFunction::Create( + FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, + DataTypeVector ret_types, std::function)> runner, + CapturedFunction* captured_func, bool is_multi_device, + std::unique_ptr* out_function) { + out_function->reset(new InstantiatedCapturedFunction( + lib, f_handle, ret_types, runner, captured_func, is_multi_device)); + return Status::OK(); +} + InstantiatedCapturedFunction::InstantiatedCapturedFunction( FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, DataTypeVector ret_types, std::function)> runner, @@ -678,13 +763,6 @@ InstantiatedCapturedFunction::InstantiatedCapturedFunction( captured_func_(captured_func), is_multi_device_(is_multi_device) {} -// NOTE: We don't release f_handle_ here and instead delegate the function -// handle releasing to the FunctionHandleCache. This is because in some cases -// (RepeatDatasetOp in particular), we want to keep the function state (e.g. -// random number generator) even after the Iterator is reset after going through -// one epoch. -InstantiatedCapturedFunction::~InstantiatedCapturedFunction() {} - Status InstantiatedCapturedFunction::Run(IteratorContext* ctx, std::vector&& args, std::vector* rets) const { @@ -884,11 +962,5 @@ bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const { return lib_->device()->device_type() != DEVICE_CPU && !is_multi_device_; } -CapturedFunction::CapturedFunction( - std::shared_ptr metadata, - std::vector captured_inputs) - : metadata_(std::move(metadata)), - captured_inputs_(std::move(captured_inputs)) {} - } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index de424fc547c..68b3ea552fc 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -52,77 +52,6 @@ Status MakeIteratorFromInputElement( Status IsNodeStateful(const FunctionLibraryDefinition& library, const NodeDef& node); -// `InstantiatedCapturedFunction` encapsulates all the runtime support needed -// to execute a tensorflow function. -// -// While `CapturedFunction` encapsulates constant attributes of the function, -// such as its name and captured arguments, `InstantiatedCapturedFunction` -// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function -// handle. -// -// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute -// functions outside of the normal `OpKernel::Compute()` context. -class InstantiatedCapturedFunction { - public: - ~InstantiatedCapturedFunction(); - - // Runs the instantiated captured function. This method takes ownership of - // the tensors in `args`, in order to be able to deallocate them as early as - // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain - // ownership of the `args`. - Status Run(IteratorContext* ctx, std::vector&& args, - std::vector* rets) const; - - // Synchronously runs the captured function on the given `args`, and stores - // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when - // possible. - Status RunWithBorrowedArgs(IteratorContext* ctx, - const std::vector& args, - std::vector* rets) const; - - // Synchronously runs the captured function on the given `args`, and stores - // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when - // possible. This can be useful for calling a captured - // function in cases where an `IteratorContext*` is not available - // (such as a destructor). - // TODO(b/144278100): Avoid running functions without IteratorContext. - Status RunInstantiated(const std::vector& args, - std::vector* rets); - - // Asynchronously runs the captured function on the given `args`, stores - // the results in `*rets`, and calls the given `done` callback when the - // function returns. This method takes ownership of the tensors in `args`, - // in order to be able to deallocate them as early as possible. - void RunAsync(IteratorContext* ctx, std::vector&& args, - std::vector* rets, - FunctionLibraryRuntime::DoneCallback done, - const std::shared_ptr& node) const; - - private: - InstantiatedCapturedFunction( - FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, - DataTypeVector ret_types, - std::function)> runner, - CapturedFunction* captured_func, bool is_multi_device); - - // Determines whether a rendezvous object should be created when running the - // instantiated function. - bool ShouldCreateRendezvous() const; - - friend class CapturedFunction; - - FunctionLibraryRuntime* const lib_; - const FunctionLibraryRuntime::Handle f_handle_; - const DataTypeVector ret_types_; - // Note: We capture the runner at function instantiation time to be able to - // run the function without `IteratorContext` via `RunInstantiated`. - std::function)> captured_runner_; - CapturedFunction* const captured_func_; - bool const is_multi_device_; - - TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction); -}; - struct ShortCircuitInfo { std::vector indices; std::vector can_move; @@ -217,12 +146,6 @@ class CapturedFunction { std::unique_ptr* instantiated_captured_function); - // Determines whether the captured function is stateful. - // - // TODO(jsimsa): Remove this method once all users of `CapturedFunction` - // migrate to `CheckExternalState`. - bool IsStateful() const; - // Determines whether the captured function is stateful. Status CheckExternalState() const; @@ -256,11 +179,90 @@ class CapturedFunction { CapturedFunction(std::shared_ptr metadata, std::vector captured_inputs); + Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device) const; + const std::shared_ptr metadata_; const std::vector captured_inputs_; TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction); }; + +// `InstantiatedCapturedFunction` encapsulates all the runtime support needed +// to execute a tensorflow function. +// +// While `CapturedFunction` encapsulates constant attributes of the function, +// such as its name and captured arguments, `InstantiatedCapturedFunction` +// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function +// handle. +// +// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute +// functions outside of the normal `OpKernel::Compute()` context. +class InstantiatedCapturedFunction { + public: + // Creates a new instance of the `InstantiatedCapturedFunction` class from the + // given inputs. + static Status Create( + FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, + DataTypeVector ret_types, + std::function)> runner, + CapturedFunction* captured_func, bool is_multi_device, + std::unique_ptr* out_function); + + // Runs the instantiated captured function. This method takes ownership of + // the tensors in `args`, in order to be able to deallocate them as early as + // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain + // ownership of the `args`. + Status Run(IteratorContext* ctx, std::vector&& args, + std::vector* rets) const; + + // Synchronously runs the captured function on the given `args`, and stores + // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when + // possible. + Status RunWithBorrowedArgs(IteratorContext* ctx, + const std::vector& args, + std::vector* rets) const; + + // Synchronously runs the captured function on the given `args`, and stores + // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when + // possible. This can be useful for calling a captured function in cases where + // an `IteratorContext*` is not available (such as a destructor). + // + // TODO(b/144278100): Avoid running functions without IteratorContext. + Status RunInstantiated(const std::vector& args, + std::vector* rets); + + // Asynchronously runs the captured function on the given `args`, stores the + // results in `*rets`, and calls the given `done` callback when the function + // returns. This method takes ownership of the tensors in `args`, in order to + // be able to deallocate them as early as possible. + void RunAsync(IteratorContext* ctx, std::vector&& args, + std::vector* rets, + FunctionLibraryRuntime::DoneCallback done, + const std::shared_ptr& node) const; + + private: + InstantiatedCapturedFunction( + FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, + DataTypeVector ret_types, + std::function)> runner, + CapturedFunction* captured_func, bool is_multi_device); + + // Determines whether a rendezvous object should be created when running the + // instantiated function. + bool ShouldCreateRendezvous() const; + + FunctionLibraryRuntime* const lib_; // Not owned. + const FunctionLibraryRuntime::Handle f_handle_; + const DataTypeVector ret_types_; + // Note: We capture the runner at function instantiation time to be able to + // run the function without `IteratorContext` via `RunInstantiated`. + std::function)> captured_runner_; + CapturedFunction* const captured_func_; // Not owned. + const bool is_multi_device_; + + TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction); +}; + } // namespace data // TODO(b/114112161): Remove these aliases when all users have moved over to the