Avoid reference counting handles in sync mode

In sync mode the caller keeps the input and output handles alive for the
duration of the execution, so ExecuteNode only needs to take its own
references on them when it runs asynchronously and may outlive the caller.

PiperOrigin-RevId: 280585445
Change-Id: I4197d5fdb25f9def94b83dccc1c43463407b10e5
parent 4f0d5ccd23
commit 919253c654
@@ -667,10 +667,10 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
                          output_dtypes[i], ctx, &retvals[i]));
   }
 
-  std::unique_ptr<EagerNode> node(
-      new ExecuteNode(ctx, op->Inputs(), op->remote_func_params(),
-                      std::move(kernel), graph_collector, output_dtypes,
-                      op->GetCancellationManager(), {retvals, num_outputs}));
+  std::unique_ptr<EagerNode> node(new ExecuteNode(
+      ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
+      graph_collector, output_dtypes, op->GetCancellationManager(),
+      executor.Async(), {retvals, num_outputs}));
   // Note that for async mode, execution order will make sure that all
   // input handles are ready before executing them.
   // TODO(b/137118203): Consider executing "cheap" kernels inline for
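For context, a minimal sketch of the sync-versus-async dispatch this flag feeds
into. This is not the TensorFlow sources; the names ToyExecutor, ToyNode, and
PrintNode are hypothetical, and it assumes (as the comment in the hunk above
suggests) that a synchronous executor runs nodes inline on the caller's stack,
which is what makes the node's own reference counting skippable in sync mode:

#include <cstdio>
#include <memory>
#include <queue>

struct ToyNode {
  virtual ~ToyNode() = default;
  virtual void Run() = 0;
};

struct PrintNode : ToyNode {
  void Run() override { std::printf("running node\n"); }
};

class ToyExecutor {
 public:
  explicit ToyExecutor(bool async) : async_(async) {}
  bool Async() const { return async_; }

  // Sync: run inline while the caller's frame (and the references it holds)
  // is still alive. Async: queue the node; it may outlive the caller.
  void AddOrExecute(std::unique_ptr<ToyNode> node) {
    if (!async_) {
      node->Run();
    } else {
      queue_.push(std::move(node));
    }
  }

 private:
  const bool async_;
  std::queue<std::unique_ptr<ToyNode>> queue_;
};

int main() {
  ToyExecutor executor(/*async=*/false);
  // Mirrors the change above: the node is told executor.Async() at
  // construction time so it knows whether to take its own references.
  executor.AddOrExecute(std::make_unique<PrintNode>());
  return 0;
}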
@@ -86,7 +86,7 @@ class ExecuteNode : public EagerNode {
       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
       core::RefCountPtr<KernelAndDevice> kernel,
       GraphCollector* graph_collector, const DataTypeVector& output_dtypes,
-      CancellationManager* cancellation_manager,
+      CancellationManager* cancellation_manager, bool async,
       absl::Span<TensorHandle*> retvals)
       : EagerNode(),
         ctx_(ctx),
@@ -94,28 +94,36 @@ class ExecuteNode : public EagerNode {
         remote_func_params_(remote_func_params),
         kernel_(std::move(kernel)),
         graph_collector_(graph_collector),
-        cancellation_manager_(cancellation_manager) {
+        cancellation_manager_(cancellation_manager),
+        async_(async) {
     // Copy the output handles, since the container for them might get
     // destroyed.
     for (auto handle : retvals) {
-      handle->Ref();
       retvals_.push_back(handle);
     }
 
-    // This is required to ensure that the tensor handles stay alive across the
-    // execution.
-    for (auto handle : inputs_) {
-      handle->Ref();
+    if (async_) {
+      // This is required to ensure that the tensor handles stay alive across
+      // the execution.
+      for (auto handle : inputs_) {
+        handle->Ref();
+      }
+
+      for (auto handle : retvals_) {
+        handle->Ref();
+      }
     }
   }
 
   ~ExecuteNode() override {
-    for (auto handle : retvals_) {
-      handle->Unref();
-    }
+    if (async_) {
+      for (auto handle : retvals_) {
+        handle->Unref();
+      }
 
-    for (auto handle : inputs_) {
-      handle->Unref();
+      for (auto handle : inputs_) {
+        handle->Unref();
+      }
     }
   }
 
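The core of the change is this constructor/destructor pair: Ref() and Unref()
on every input and output handle now happen only on the async path. For
illustration, a self-contained toy version of the pattern; the names ToyHandle
and ToyExecuteNode are hypothetical stand-ins, not the TensorFlow classes:

#include <cassert>
#include <utility>
#include <vector>

class ToyHandle {
 public:
  void Ref() { ++count_; }
  void Unref() {
    assert(count_ > 0);
    --count_;
  }
  int count() const { return count_; }

 private:
  int count_ = 1;  // The creator starts with one reference.
};

class ToyExecuteNode {
 public:
  ToyExecuteNode(std::vector<ToyHandle*> handles, bool async)
      : handles_(std::move(handles)), async_(async) {
    // Async only: the node may outlive the caller, so it must hold its own
    // references to keep the handles alive across execution.
    if (async_) {
      for (auto* h : handles_) h->Ref();
    }
  }
  ~ToyExecuteNode() {
    // Matches the constructor: drop the extra references in async mode.
    if (async_) {
      for (auto* h : handles_) h->Unref();
    }
  }

 private:
  std::vector<ToyHandle*> handles_;
  const bool async_;
};

int main() {
  ToyHandle h;
  {
    ToyExecuteNode node({&h}, /*async=*/false);
    // In sync mode the op would run here, while the caller still owns `h`,
    // so no Ref/Unref churn is needed.
  }
  assert(h.count() == 1);  // Sync mode left the reference count untouched.
  return 0;
}

The design point: atomic reference-count updates on every handle of every op
are pure overhead in sync mode, where the caller's scope already guarantees
the handles' lifetimes.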
@@ -151,6 +159,7 @@ class ExecuteNode : public EagerNode {
   core::RefCountPtr<KernelAndDevice> kernel_;
   GraphCollector* graph_collector_;
   CancellationManager* const cancellation_manager_;
+  const bool async_;
   gtl::InlinedVector<TensorHandle*, 2> retvals_;
 };
 