[tf.data] Cleanup of C++ user-defined function execution utilities.
This CL introduces a distinction between whether a function is multi-device (`is_multi_device`) and when a multi-device function backend is used to instantiate and execute the function. The difference is needed because the multi-device backend will not create a rendezvous for single device functions and if the function is not executed on CPU, it needs one. PiperOrigin-RevId: 315998784 Change-Id: I8aebad07240298b6d21c0d7126b271f9952b264d
This commit is contained in:
parent
7933b0e5e3
commit
151ad23522
@ -565,8 +565,7 @@ Status CapturedFunction::Instantiate(
|
||||
if (!metadata_->use_inter_op_parallelism()) {
|
||||
inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
|
||||
}
|
||||
bool is_multi_device = metadata_->use_multi_device_function();
|
||||
inst_opts.is_multi_device_function = is_multi_device;
|
||||
inst_opts.is_multi_device_function = metadata_->use_multi_device_function();
|
||||
|
||||
// We infer the target device from the function library runtime.
|
||||
DCHECK(lib->device() != nullptr);
|
||||
@ -649,16 +648,13 @@ Status CapturedFunction::Instantiate(
|
||||
DataTypeVector ret_types;
|
||||
TF_RETURN_IF_ERROR(lib->GetRetTypes(f_handle, &ret_types));
|
||||
|
||||
*instantiated_captured_function =
|
||||
absl::WrapUnique<InstantiatedCapturedFunction>(
|
||||
new InstantiatedCapturedFunction(lib, f_handle, std::move(ret_types),
|
||||
*ctx->runner(), this,
|
||||
is_multi_device));
|
||||
return Status::OK();
|
||||
bool is_multi_device;
|
||||
TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));
|
||||
return InstantiatedCapturedFunction::Create(
|
||||
lib, f_handle, std::move(ret_types), *ctx->runner(), this,
|
||||
is_multi_device, instantiated_captured_function);
|
||||
}
|
||||
|
||||
bool CapturedFunction::IsStateful() const { return !CheckExternalState().ok(); }
|
||||
|
||||
Status CapturedFunction::CheckExternalState() const {
|
||||
for (const auto& name : lib_def()->ListFunctionNames()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -667,6 +663,95 @@ Status CapturedFunction::CheckExternalState() const {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
CapturedFunction::CapturedFunction(
|
||||
std::shared_ptr<const FunctionMetadata> metadata,
|
||||
std::vector<Tensor> captured_inputs)
|
||||
: metadata_(std::move(metadata)),
|
||||
captured_inputs_(std::move(captured_inputs)) {}
|
||||
|
||||
Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
|
||||
bool* is_multi_device) const {
|
||||
if (!metadata_->use_multi_device_function()) {
|
||||
*is_multi_device = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const FunctionDef* fdef;
|
||||
TF_RETURN_IF_ERROR(
|
||||
LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
|
||||
|
||||
Device* current_device = ctx->flr()->device();
|
||||
DeviceType current_device_type(current_device->device_type());
|
||||
DeviceNameUtils::ParsedName current_device_name;
|
||||
if (!DeviceNameUtils::ParseFullName(current_device->name(),
|
||||
¤t_device_name)) {
|
||||
return errors::InvalidArgument("Failed to parse device name: ",
|
||||
current_device->name());
|
||||
}
|
||||
|
||||
// Check if any of the captured inputs are placed on a device not compatible
|
||||
// with the current device. For non-captured inputs, we assume they are placed
|
||||
// on the current device.
|
||||
for (const auto& input : captured_inputs_) {
|
||||
DataType dtype = input.dtype();
|
||||
if (dtype == DT_RESOURCE) {
|
||||
const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
|
||||
DeviceNameUtils::ParsedName resource_device_name;
|
||||
if (!DeviceNameUtils::ParseFullName(handle.device(),
|
||||
&resource_device_name)) {
|
||||
return errors::InvalidArgument("Failed to parse device name: ",
|
||||
handle.device());
|
||||
}
|
||||
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
|
||||
resource_device_name)) {
|
||||
*is_multi_device = true;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all ops could be placed on the current device.
|
||||
for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
|
||||
const FunctionDef* fdef;
|
||||
TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
|
||||
for (const auto& node : fdef->node_def()) {
|
||||
// Check if the op has a kernel available for the current device.
|
||||
if (!KernelDefAvailable(current_device_type, node)) {
|
||||
*is_multi_device = true;
|
||||
return Status::OK();
|
||||
}
|
||||
// If the op has a requested device, check if the requested device is
|
||||
// compatible with the current device.
|
||||
if (!node.device().empty()) {
|
||||
DeviceNameUtils::ParsedName node_device_name;
|
||||
if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
|
||||
return errors::InvalidArgument("Failed to parse device name: ",
|
||||
node.device());
|
||||
}
|
||||
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
|
||||
node_device_name)) {
|
||||
*is_multi_device = true;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
*is_multi_device = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* static */
|
||||
Status InstantiatedCapturedFunction::Create(
|
||||
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
|
||||
DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
|
||||
CapturedFunction* captured_func, bool is_multi_device,
|
||||
std::unique_ptr<InstantiatedCapturedFunction>* out_function) {
|
||||
out_function->reset(new InstantiatedCapturedFunction(
|
||||
lib, f_handle, ret_types, runner, captured_func, is_multi_device));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
InstantiatedCapturedFunction::InstantiatedCapturedFunction(
|
||||
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
|
||||
DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
|
||||
@ -678,13 +763,6 @@ InstantiatedCapturedFunction::InstantiatedCapturedFunction(
|
||||
captured_func_(captured_func),
|
||||
is_multi_device_(is_multi_device) {}
|
||||
|
||||
// NOTE: We don't release f_handle_ here and instead delegate the function
|
||||
// handle releasing to the FunctionHandleCache. This is because in some cases
|
||||
// (RepeatDatasetOp in particular), we want to keep the function state (e.g.
|
||||
// random number generator) even after the Iterator is reset after going through
|
||||
// one epoch.
|
||||
InstantiatedCapturedFunction::~InstantiatedCapturedFunction() {}
|
||||
|
||||
Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
|
||||
std::vector<Tensor>&& args,
|
||||
std::vector<Tensor>* rets) const {
|
||||
@ -884,11 +962,5 @@ bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const {
|
||||
return lib_->device()->device_type() != DEVICE_CPU && !is_multi_device_;
|
||||
}
|
||||
|
||||
CapturedFunction::CapturedFunction(
|
||||
std::shared_ptr<const FunctionMetadata> metadata,
|
||||
std::vector<Tensor> captured_inputs)
|
||||
: metadata_(std::move(metadata)),
|
||||
captured_inputs_(std::move(captured_inputs)) {}
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -52,77 +52,6 @@ Status MakeIteratorFromInputElement(
|
||||
Status IsNodeStateful(const FunctionLibraryDefinition& library,
|
||||
const NodeDef& node);
|
||||
|
||||
// `InstantiatedCapturedFunction` encapsulates all the runtime support needed
|
||||
// to execute a tensorflow function.
|
||||
//
|
||||
// While `CapturedFunction` encapsulates constant attributes of the function,
|
||||
// such as its name and captured arguments, `InstantiatedCapturedFunction`
|
||||
// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function
|
||||
// handle.
|
||||
//
|
||||
// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute
|
||||
// functions outside of the normal `OpKernel::Compute()` context.
|
||||
class InstantiatedCapturedFunction {
|
||||
public:
|
||||
~InstantiatedCapturedFunction();
|
||||
|
||||
// Runs the instantiated captured function. This method takes ownership of
|
||||
// the tensors in `args`, in order to be able to deallocate them as early as
|
||||
// possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
|
||||
// ownership of the `args`.
|
||||
Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
|
||||
std::vector<Tensor>* rets) const;
|
||||
|
||||
// Synchronously runs the captured function on the given `args`, and stores
|
||||
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
|
||||
// possible.
|
||||
Status RunWithBorrowedArgs(IteratorContext* ctx,
|
||||
const std::vector<Tensor>& args,
|
||||
std::vector<Tensor>* rets) const;
|
||||
|
||||
// Synchronously runs the captured function on the given `args`, and stores
|
||||
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
|
||||
// possible. This can be useful for calling a captured
|
||||
// function in cases where an `IteratorContext*` is not available
|
||||
// (such as a destructor).
|
||||
// TODO(b/144278100): Avoid running functions without IteratorContext.
|
||||
Status RunInstantiated(const std::vector<Tensor>& args,
|
||||
std::vector<Tensor>* rets);
|
||||
|
||||
// Asynchronously runs the captured function on the given `args`, stores
|
||||
// the results in `*rets`, and calls the given `done` callback when the
|
||||
// function returns. This method takes ownership of the tensors in `args`,
|
||||
// in order to be able to deallocate them as early as possible.
|
||||
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
|
||||
std::vector<Tensor>* rets,
|
||||
FunctionLibraryRuntime::DoneCallback done,
|
||||
const std::shared_ptr<model::Node>& node) const;
|
||||
|
||||
private:
|
||||
InstantiatedCapturedFunction(
|
||||
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
|
||||
DataTypeVector ret_types,
|
||||
std::function<void(std::function<void()>)> runner,
|
||||
CapturedFunction* captured_func, bool is_multi_device);
|
||||
|
||||
// Determines whether a rendezvous object should be created when running the
|
||||
// instantiated function.
|
||||
bool ShouldCreateRendezvous() const;
|
||||
|
||||
friend class CapturedFunction;
|
||||
|
||||
FunctionLibraryRuntime* const lib_;
|
||||
const FunctionLibraryRuntime::Handle f_handle_;
|
||||
const DataTypeVector ret_types_;
|
||||
// Note: We capture the runner at function instantiation time to be able to
|
||||
// run the function without `IteratorContext` via `RunInstantiated`.
|
||||
std::function<void(std::function<void()>)> captured_runner_;
|
||||
CapturedFunction* const captured_func_;
|
||||
bool const is_multi_device_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
|
||||
};
|
||||
|
||||
struct ShortCircuitInfo {
|
||||
std::vector<int> indices;
|
||||
std::vector<bool> can_move;
|
||||
@ -217,12 +146,6 @@ class CapturedFunction {
|
||||
std::unique_ptr<InstantiatedCapturedFunction>*
|
||||
instantiated_captured_function);
|
||||
|
||||
// Determines whether the captured function is stateful.
|
||||
//
|
||||
// TODO(jsimsa): Remove this method once all users of `CapturedFunction`
|
||||
// migrate to `CheckExternalState`.
|
||||
bool IsStateful() const;
|
||||
|
||||
// Determines whether the captured function is stateful.
|
||||
Status CheckExternalState() const;
|
||||
|
||||
@ -256,11 +179,90 @@ class CapturedFunction {
|
||||
CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
|
||||
std::vector<Tensor> captured_inputs);
|
||||
|
||||
Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device) const;
|
||||
|
||||
const std::shared_ptr<const FunctionMetadata> metadata_;
|
||||
const std::vector<Tensor> captured_inputs_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
|
||||
};
|
||||
|
||||
// `InstantiatedCapturedFunction` encapsulates all the runtime support needed
|
||||
// to execute a tensorflow function.
|
||||
//
|
||||
// While `CapturedFunction` encapsulates constant attributes of the function,
|
||||
// such as its name and captured arguments, `InstantiatedCapturedFunction`
|
||||
// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function
|
||||
// handle.
|
||||
//
|
||||
// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute
|
||||
// functions outside of the normal `OpKernel::Compute()` context.
|
||||
class InstantiatedCapturedFunction {
|
||||
public:
|
||||
// Creates a new instance of the `InstantiatedCapturedFunction` class from the
|
||||
// given inputs.
|
||||
static Status Create(
|
||||
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
|
||||
DataTypeVector ret_types,
|
||||
std::function<void(std::function<void()>)> runner,
|
||||
CapturedFunction* captured_func, bool is_multi_device,
|
||||
std::unique_ptr<InstantiatedCapturedFunction>* out_function);
|
||||
|
||||
// Runs the instantiated captured function. This method takes ownership of
|
||||
// the tensors in `args`, in order to be able to deallocate them as early as
|
||||
// possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
|
||||
// ownership of the `args`.
|
||||
Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
|
||||
std::vector<Tensor>* rets) const;
|
||||
|
||||
// Synchronously runs the captured function on the given `args`, and stores
|
||||
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
|
||||
// possible.
|
||||
Status RunWithBorrowedArgs(IteratorContext* ctx,
|
||||
const std::vector<Tensor>& args,
|
||||
std::vector<Tensor>* rets) const;
|
||||
|
||||
// Synchronously runs the captured function on the given `args`, and stores
|
||||
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
|
||||
// possible. This can be useful for calling a captured function in cases where
|
||||
// an `IteratorContext*` is not available (such as a destructor).
|
||||
//
|
||||
// TODO(b/144278100): Avoid running functions without IteratorContext.
|
||||
Status RunInstantiated(const std::vector<Tensor>& args,
|
||||
std::vector<Tensor>* rets);
|
||||
|
||||
// Asynchronously runs the captured function on the given `args`, stores the
|
||||
// results in `*rets`, and calls the given `done` callback when the function
|
||||
// returns. This method takes ownership of the tensors in `args`, in order to
|
||||
// be able to deallocate them as early as possible.
|
||||
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
|
||||
std::vector<Tensor>* rets,
|
||||
FunctionLibraryRuntime::DoneCallback done,
|
||||
const std::shared_ptr<model::Node>& node) const;
|
||||
|
||||
private:
|
||||
InstantiatedCapturedFunction(
|
||||
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
|
||||
DataTypeVector ret_types,
|
||||
std::function<void(std::function<void()>)> runner,
|
||||
CapturedFunction* captured_func, bool is_multi_device);
|
||||
|
||||
// Determines whether a rendezvous object should be created when running the
|
||||
// instantiated function.
|
||||
bool ShouldCreateRendezvous() const;
|
||||
|
||||
FunctionLibraryRuntime* const lib_; // Not owned.
|
||||
const FunctionLibraryRuntime::Handle f_handle_;
|
||||
const DataTypeVector ret_types_;
|
||||
// Note: We capture the runner at function instantiation time to be able to
|
||||
// run the function without `IteratorContext` via `RunInstantiated`.
|
||||
std::function<void(std::function<void()>)> captured_runner_;
|
||||
CapturedFunction* const captured_func_; // Not owned.
|
||||
const bool is_multi_device_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
||||
// TODO(b/114112161): Remove these aliases when all users have moved over to the
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user