[tf.data] Cleanup of C++ user-defined function execution utilities.

This CL introduces a distinction between whether a function is multi-device (`is_multi_device`) and whether the multi-device function backend is used to instantiate and execute it. The distinction is needed because the multi-device backend will not create a rendezvous for single-device functions, yet a single-device function that does not execute on the CPU still needs one (see the sketch below).

PiperOrigin-RevId: 315998784
Change-Id: I8aebad07240298b6d21c0d7126b271f9952b264d
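To make the new distinction concrete, here is a minimal sketch of the two decision points, mirroring the code in the diff below: the backend-level flag `is_multi_device_function` continues to track the function metadata, while the per-instantiation `is_multi_device_` bit is derived from the actual placement of captured inputs and function ops and gates rendezvous creation.

// Backend selection (in CapturedFunction::Instantiate): unchanged in meaning,
// the multi-device backend is used whenever the metadata requests it.
inst_opts.is_multi_device_function = metadata_->use_multi_device_function();

// Placement-based classification, computed per instantiation.
bool is_multi_device;
TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));

// Rendezvous decision (InstantiatedCapturedFunction::ShouldCreateRendezvous):
// a rendezvous is created here only for single-device functions that do not
// execute on the CPU, since the multi-device backend will not create one for them.
return lib_->device()->device_type() != DEVICE_CPU && !is_multi_device_;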
Author: Jiri Simsa (2020-06-11 16:15:08 -07:00), committed by TensorFlower Gardener
Parent: 7933b0e5e3
Commit: 151ad23522
2 changed files with 174 additions and 100 deletions


@@ -565,8 +565,7 @@ Status CapturedFunction::Instantiate(
if (!metadata_->use_inter_op_parallelism()) {
inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
}
bool is_multi_device = metadata_->use_multi_device_function();
inst_opts.is_multi_device_function = is_multi_device;
inst_opts.is_multi_device_function = metadata_->use_multi_device_function();
// We infer the target device from the function library runtime.
DCHECK(lib->device() != nullptr);
@@ -649,16 +648,13 @@ Status CapturedFunction::Instantiate(
DataTypeVector ret_types;
TF_RETURN_IF_ERROR(lib->GetRetTypes(f_handle, &ret_types));
*instantiated_captured_function =
absl::WrapUnique<InstantiatedCapturedFunction>(
new InstantiatedCapturedFunction(lib, f_handle, std::move(ret_types),
*ctx->runner(), this,
is_multi_device));
return Status::OK();
bool is_multi_device;
TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));
return InstantiatedCapturedFunction::Create(
lib, f_handle, std::move(ret_types), *ctx->runner(), this,
is_multi_device, instantiated_captured_function);
}
bool CapturedFunction::IsStateful() const { return !CheckExternalState().ok(); }
Status CapturedFunction::CheckExternalState() const {
for (const auto& name : lib_def()->ListFunctionNames()) {
TF_RETURN_IF_ERROR(
@@ -667,6 +663,95 @@ Status CapturedFunction::CheckExternalState() const {
return Status::OK();
}
CapturedFunction::CapturedFunction(
std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor> captured_inputs)
: metadata_(std::move(metadata)),
captured_inputs_(std::move(captured_inputs)) {}
Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
bool* is_multi_device) const {
if (!metadata_->use_multi_device_function()) {
*is_multi_device = false;
return Status::OK();
}
const FunctionDef* fdef;
TF_RETURN_IF_ERROR(
LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
Device* current_device = ctx->flr()->device();
DeviceType current_device_type(current_device->device_type());
DeviceNameUtils::ParsedName current_device_name;
if (!DeviceNameUtils::ParseFullName(current_device->name(),
&current_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
current_device->name());
}
// Check if any of the captured inputs are placed on a device not compatible
// with the current device. For non-captured inputs, we assume they are placed
// on the current device.
for (const auto& input : captured_inputs_) {
DataType dtype = input.dtype();
if (dtype == DT_RESOURCE) {
const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
DeviceNameUtils::ParsedName resource_device_name;
if (!DeviceNameUtils::ParseFullName(handle.device(),
&resource_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
handle.device());
}
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
resource_device_name)) {
*is_multi_device = true;
return Status::OK();
}
}
}
// Check if all ops could be placed on the current device.
for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
const FunctionDef* fdef;
TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
for (const auto& node : fdef->node_def()) {
// Check if the op has a kernel available for the current device.
if (!KernelDefAvailable(current_device_type, node)) {
*is_multi_device = true;
return Status::OK();
}
// If the op has a requested device, check if the requested device is
// compatible with the current device.
if (!node.device().empty()) {
DeviceNameUtils::ParsedName node_device_name;
if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
node.device());
}
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
node_device_name)) {
*is_multi_device = true;
return Status::OK();
}
}
}
}
*is_multi_device = false;
return Status::OK();
}
/* static */
Status InstantiatedCapturedFunction::Create(
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
CapturedFunction* captured_func, bool is_multi_device,
std::unique_ptr<InstantiatedCapturedFunction>* out_function) {
out_function->reset(new InstantiatedCapturedFunction(
lib, f_handle, ret_types, runner, captured_func, is_multi_device));
return Status::OK();
}
InstantiatedCapturedFunction::InstantiatedCapturedFunction(
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
@@ -678,13 +763,6 @@ InstantiatedCapturedFunction::InstantiatedCapturedFunction(
captured_func_(captured_func),
is_multi_device_(is_multi_device) {}
// NOTE: We don't release f_handle_ here and instead delegate the function
// handle releasing to the FunctionHandleCache. This is because in some cases
// (RepeatDatasetOp in particular), we want to keep the function state (e.g.
// random number generator) even after the Iterator is reset after going through
// one epoch.
InstantiatedCapturedFunction::~InstantiatedCapturedFunction() {}
Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets) const {
@@ -884,11 +962,5 @@ bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const {
return lib_->device()->device_type() != DEVICE_CPU && !is_multi_device_;
}
CapturedFunction::CapturedFunction(
std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor> captured_inputs)
: metadata_(std::move(metadata)),
captured_inputs_(std::move(captured_inputs)) {}
} // namespace data
} // namespace tensorflow
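Before the header changes below, a rough usage sketch for orientation: tf.data iterator kernels obtain an `InstantiatedCapturedFunction` via `CapturedFunction::Instantiate()` (which now performs the placement check) and then call `Run()`. The member names `captured_func_` and `instantiated_captured_func_` here are illustrative, not part of this change.

// Illustrative only: typical call pattern inside a dataset iterator.
Status Initialize(IteratorContext* ctx) {
  // Instantiate() now derives `is_multi_device` from the placement of captured
  // inputs and function ops before creating the InstantiatedCapturedFunction.
  return captured_func_->Instantiate(ctx, &instantiated_captured_func_);
}

Status MapElement(IteratorContext* ctx, std::vector<Tensor>&& input_element,
                  std::vector<Tensor>* out_tensors) {
  // Run() takes ownership of its arguments; RunWithBorrowedArgs() is the
  // alternative when the caller still needs them afterwards.
  return instantiated_captured_func_->Run(ctx, std::move(input_element),
                                          out_tensors);
}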


@@ -52,77 +52,6 @@ Status MakeIteratorFromInputElement(
Status IsNodeStateful(const FunctionLibraryDefinition& library,
const NodeDef& node);
// `InstantiatedCapturedFunction` encapsulates all the runtime support needed
// to execute a tensorflow function.
//
// While `CapturedFunction` encapsulates constant attributes of the function,
// such as its name and captured arguments, `InstantiatedCapturedFunction`
// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function
// handle.
//
// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute
// functions outside of the normal `OpKernel::Compute()` context.
class InstantiatedCapturedFunction {
public:
~InstantiatedCapturedFunction();
// Runs the instantiated captured function. This method takes ownership of
// the tensors in `args`, in order to be able to deallocate them as early as
// possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
// ownership of the `args`.
Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
// possible.
Status RunWithBorrowedArgs(IteratorContext* ctx,
const std::vector<Tensor>& args,
std::vector<Tensor>* rets) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
// possible. This can be useful for calling a captured
// function in cases where an `IteratorContext*` is not available
// (such as a destructor).
// TODO(b/144278100): Avoid running functions without IteratorContext.
Status RunInstantiated(const std::vector<Tensor>& args,
std::vector<Tensor>* rets);
// Asynchronously runs the captured function on the given `args`, stores
// the results in `*rets`, and calls the given `done` callback when the
// function returns. This method takes ownership of the tensors in `args`,
// in order to be able to deallocate them as early as possible.
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done,
const std::shared_ptr<model::Node>& node) const;
private:
InstantiatedCapturedFunction(
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
DataTypeVector ret_types,
std::function<void(std::function<void()>)> runner,
CapturedFunction* captured_func, bool is_multi_device);
// Determines whether a rendezvous object should be created when running the
// instantiated function.
bool ShouldCreateRendezvous() const;
friend class CapturedFunction;
FunctionLibraryRuntime* const lib_;
const FunctionLibraryRuntime::Handle f_handle_;
const DataTypeVector ret_types_;
// Note: We capture the runner at function instantiation time to be able to
// run the function without `IteratorContext` via `RunInstantiated`.
std::function<void(std::function<void()>)> captured_runner_;
CapturedFunction* const captured_func_;
bool const is_multi_device_;
TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
};
struct ShortCircuitInfo {
std::vector<int> indices;
std::vector<bool> can_move;
@@ -217,12 +146,6 @@ class CapturedFunction {
std::unique_ptr<InstantiatedCapturedFunction>*
instantiated_captured_function);
// Determines whether the captured function is stateful.
//
// TODO(jsimsa): Remove this method once all users of `CapturedFunction`
// migrate to `CheckExternalState`.
bool IsStateful() const;
// Determines whether the captured function is stateful.
Status CheckExternalState() const;
@@ -256,11 +179,90 @@
CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor> captured_inputs);
Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device) const;
const std::shared_ptr<const FunctionMetadata> metadata_;
const std::vector<Tensor> captured_inputs_;
TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
};
// `InstantiatedCapturedFunction` encapsulates all the runtime support needed
// to execute a tensorflow function.
//
// While `CapturedFunction` encapsulates constant attributes of the function,
// such as its name and captured arguments, `InstantiatedCapturedFunction`
// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function
// handle.
//
// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute
// functions outside of the normal `OpKernel::Compute()` context.
class InstantiatedCapturedFunction {
public:
// Creates a new instance of the `InstantiatedCapturedFunction` class from the
// given inputs.
static Status Create(
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
DataTypeVector ret_types,
std::function<void(std::function<void()>)> runner,
CapturedFunction* captured_func, bool is_multi_device,
std::unique_ptr<InstantiatedCapturedFunction>* out_function);
// Runs the instantiated captured function. This method takes ownership of
// the tensors in `args`, in order to be able to deallocate them as early as
// possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
// ownership of the `args`.
Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
// possible.
Status RunWithBorrowedArgs(IteratorContext* ctx,
const std::vector<Tensor>& args,
std::vector<Tensor>* rets) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
// possible. This can be useful for calling a captured function in cases where
// an `IteratorContext*` is not available (such as a destructor).
//
// TODO(b/144278100): Avoid running functions without IteratorContext.
Status RunInstantiated(const std::vector<Tensor>& args,
std::vector<Tensor>* rets);
// Asynchronously runs the captured function on the given `args`, stores the
// results in `*rets`, and calls the given `done` callback when the function
// returns. This method takes ownership of the tensors in `args`, in order to
// be able to deallocate them as early as possible.
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done,
const std::shared_ptr<model::Node>& node) const;
private:
InstantiatedCapturedFunction(
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
DataTypeVector ret_types,
std::function<void(std::function<void()>)> runner,
CapturedFunction* captured_func, bool is_multi_device);
// Determines whether a rendezvous object should be created when running the
// instantiated function.
bool ShouldCreateRendezvous() const;
FunctionLibraryRuntime* const lib_; // Not owned.
const FunctionLibraryRuntime::Handle f_handle_;
const DataTypeVector ret_types_;
// Note: We capture the runner at function instantiation time to be able to
// run the function without `IteratorContext` via `RunInstantiated`.
std::function<void(std::function<void()>)> captured_runner_;
CapturedFunction* const captured_func_; // Not owned.
const bool is_multi_device_;
TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
};
} // namespace data
// TODO(b/114112161): Remove these aliases when all users have moved over to the