Avoid reference counting handles in sync mode

PiperOrigin-RevId: 280585445
Change-Id: I4197d5fdb25f9def94b83dccc1c43463407b10e5
This commit is contained in:
Gaurav Jain 2019-11-14 21:39:26 -08:00 committed by TensorFlower Gardener
parent 4f0d5ccd23
commit 919253c654
2 changed files with 25 additions and 16 deletions

View File

@ -667,10 +667,10 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
output_dtypes[i], ctx, &retvals[i]));
}
std::unique_ptr<EagerNode> node(
new ExecuteNode(ctx, op->Inputs(), op->remote_func_params(),
std::move(kernel), graph_collector, output_dtypes,
op->GetCancellationManager(), {retvals, num_outputs}));
std::unique_ptr<EagerNode> node(new ExecuteNode(
ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
graph_collector, output_dtypes, op->GetCancellationManager(),
executor.Async(), {retvals, num_outputs}));
// Note that for async mode, execution order will make sure that all
// input handles are ready before executing them.
// TODO(b/137118203): Consider executing "cheap" kernels inline for

View File

@ -86,7 +86,7 @@ class ExecuteNode : public EagerNode {
const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
core::RefCountPtr<KernelAndDevice> kernel,
GraphCollector* graph_collector, const DataTypeVector& output_dtypes,
CancellationManager* cancellation_manager,
CancellationManager* cancellation_manager, bool async,
absl::Span<TensorHandle*> retvals)
: EagerNode(),
ctx_(ctx),
@ -94,28 +94,36 @@ class ExecuteNode : public EagerNode {
remote_func_params_(remote_func_params),
kernel_(std::move(kernel)),
graph_collector_(graph_collector),
cancellation_manager_(cancellation_manager) {
cancellation_manager_(cancellation_manager),
async_(async) {
// Copy the output handles, since the container for them might get
// destroyed.
for (auto handle : retvals) {
handle->Ref();
retvals_.push_back(handle);
}
// This is required to ensure that the tensor handles stay alive across the
// execution.
for (auto handle : inputs_) {
handle->Ref();
if (async_) {
// This is required to ensure that the tensor handles stay alive across
// the execution.
for (auto handle : inputs_) {
handle->Ref();
}
for (auto handle : retvals_) {
handle->Ref();
}
}
}
~ExecuteNode() override {
for (auto handle : retvals_) {
handle->Unref();
}
if (async_) {
for (auto handle : retvals_) {
handle->Unref();
}
for (auto handle : inputs_) {
handle->Unref();
for (auto handle : inputs_) {
handle->Unref();
}
}
}
@ -151,6 +159,7 @@ class ExecuteNode : public EagerNode {
core::RefCountPtr<KernelAndDevice> kernel_;
GraphCollector* graph_collector_;
CancellationManager* const cancellation_manager_;
const bool async_;
gtl::InlinedVector<TensorHandle*, 2> retvals_;
};