Use TensorReferenceVector instead of std::vector<tensorflow::Tensor> in EagerKernelExecute in execute.cc, and reorder the input_vector-filling and protected_tensors-saving phases so that the second loop is skipped when the first loop determines it is not needed.
PiperOrigin-RevId: 268619047
This commit is contained in:
parent ca8929ef56
commit eb539185b9
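In plain terms, the old code always ran two loops over the inputs: one to copy every shared Tensor into protected_tensors, and one to fill input_vector. The new code fills input_vector first and, in the same pass, records the first index whose handle is shared; the protection loop then runs only if such an index was found. Below is a minimal standalone sketch of the resulting control flow, using mock stand-ins (FakeHandle, Execute, and the int payloads are illustrative, not the real TensorFlow types):

#include <cstdio>
#include <vector>

// Illustrative stand-in for a TensorHandle: a payload plus a ref count.
struct FakeHandle {
  int value;
  int ref_count;
  bool RefCountIsOne() const { return ref_count == 1; }
};

void Execute(const std::vector<FakeHandle*>& inputs) {
  const int n = static_cast<int>(inputs.size());

  // Pass 1: fill the input values and note the first shared input, if any.
  std::vector<int> input_vector(n);
  int first_index_that_needs_protecting = -1;
  for (int i = 0; i < n; ++i) {
    input_vector[i] = inputs[i]->value;
    if (first_index_that_needs_protecting < 0 && !inputs[i]->RefCountIsOne()) {
      first_index_that_needs_protecting = i;
    }
  }

  // Pass 2 runs only when pass 1 found a shared input; in the common case
  // (every ref count is one) the loop is skipped entirely.
  std::vector<FakeHandle*> protected_inputs;
  if (first_index_that_needs_protecting >= 0) {
    for (int i = 0; i < n; ++i) {
      if (!inputs[i]->RefCountIsOne()) {
        ++inputs[i]->ref_count;  // "protect": pin the value during execution
        protected_inputs.push_back(inputs[i]);
      }
    }
  }

  // ... the kernel would run on input_vector here ...

  // Release the protection once execution is complete.
  for (FakeHandle* h : protected_inputs) --h->ref_count;
}

int main() {
  FakeHandle a{1, 1}, b{2, 3};  // b is shared (ref count > 1)
  std::vector<FakeHandle*> inputs = {&a, &b};
  Execute(inputs);
  std::printf("b.ref_count restored to %d\n", b.ref_count);  // prints 3
}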
@@ -40,6 +40,7 @@ limitations under the License.
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/logging.h"
 #include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor_reference.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -952,18 +953,25 @@ Status EagerKernelExecute(EagerContext* ctx,
   // overwritten during kernel execution. The reference count is incremented
   // below when we insert a copy of the Tensor into protected_tensors, and will
   // be decremented once execution is complete.
-  std::vector<tensorflow::Tensor> protected_tensors;
+  int first_index_that_needs_protecting = -1;
+  gtl::InlinedVector<TensorValue, 4> input_vector(op_inputs.size());
   for (int i = 0; i < op_inputs.size(); ++i) {
-    if (!op_inputs[i]->RefCountIsOne()) {
-      const Tensor* input_tensor = nullptr;
-      TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
-      protected_tensors.push_back(*input_tensor);
+    TensorHandle* in = op_inputs[i];
+    TF_RETURN_IF_ERROR(in->TensorValue(&input_vector[i]));
+    if (first_index_that_needs_protecting < 0 && !in->RefCountIsOne()) {
+      first_index_that_needs_protecting = i;
     }
   }

-  gtl::InlinedVector<TensorValue, 4> input_vector(op_inputs.size());
-  for (int i = 0; i < op_inputs.size(); ++i) {
-    TF_RETURN_IF_ERROR(op_inputs[i]->TensorValue(&input_vector[i]));
+  TensorReferenceVector protected_tensors;
+  if (first_index_that_needs_protecting >= 0) {
+    for (int i = 0; i < op_inputs.size(); ++i) {
+      if (!op_inputs[i]->RefCountIsOne()) {
+        const Tensor* input_tensor = nullptr;
+        TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
+        protected_tensors.emplace_back(TensorReference(*input_tensor));
+      }
+    }
   }

   // TODO(apassos) figure out how to record stats for ops which are a part of
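protected_tensors also changes type in this hunk: each TensorReference pins the tensor's underlying buffer with a single ref-count increment instead of copying a whole Tensor, but unlike a Tensor it has no destructor that releases the buffer, so every entry must be Unref()'d explicitly (see the next hunk). A sketch of that distinction with mock stand-ins (Buffer, MockTensor, MockTensorReference are illustrative, not the real classes from tensor.h / tensor_reference.h):

#include <cassert>
#include <utility>
#include <vector>

// Illustrative stand-in for a ref-counted tensor buffer.
struct Buffer {
  int refs = 1;
  void Ref() { ++refs; }
  void Unref() { --refs; }
};

// Copying a Tensor-like object duplicates its metadata (here: shape) in
// addition to bumping the buffer's ref count; its destructor releases
// the buffer automatically.
struct MockTensor {
  Buffer* buf;
  std::vector<long long> shape;
  MockTensor(Buffer* b, std::vector<long long> s)
      : buf(b), shape(std::move(s)) {}
  MockTensor(const MockTensor& o) : buf(o.buf), shape(o.shape) { buf->Ref(); }
  ~MockTensor() { buf->Unref(); }
};

// A TensorReference-like handle is a single pointer: it pins the buffer
// without copying metadata, and must be released explicitly via Unref().
class MockTensorReference {
 public:
  explicit MockTensorReference(const MockTensor& t) : buf_(t.buf) {
    buf_->Ref();
  }
  void Unref() const { buf_->Unref(); }

 private:
  Buffer* buf_;
};

int main() {
  Buffer buf;
  MockTensor t(&buf, {2, 3});
  MockTensorReference ref(t);  // pins the buffer: refs == 2
  assert(buf.refs == 2);
  ref.Unref();                 // explicit release, as in the hunk below
  assert(buf.refs == 1);
}

Presumably TensorReferenceVector is also a small inlined vector of TensorReference entries, so ops with few inputs avoid the heap allocation a std::vector would make; the diff itself only shows the type name.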
@@ -974,15 +982,18 @@ Status EagerKernelExecute(EagerContext* ctx,
   // device. We don't call it now because it is an unneeded overhead (it
   // acquires a lock) and we can't recover from errors anyway.
   ScopedStepContainer* container = ctx->StepContainer();
+  Status s;
   if (container == nullptr) {
-    TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, maybe_stats,
-                                   maybe_step_stats, graph_collector,
-                                   cancellation_manager));
+    s = kernel->Run(input_vector, &outputs, maybe_stats, maybe_step_stats,
+                    graph_collector, cancellation_manager);
   } else {
-    TF_RETURN_IF_ERROR(kernel->Run(container, input_vector, &outputs,
-                                   maybe_stats, maybe_step_stats,
-                                   graph_collector, cancellation_manager));
+    s = kernel->Run(container, input_vector, &outputs, maybe_stats,
+                    maybe_step_stats, graph_collector, cancellation_manager);
   }
+  for (const auto& tensor_ref : protected_tensors) {
+    tensor_ref.Unref();
+  }
+  TF_RETURN_IF_ERROR(s);
   if (graph_collector != nullptr) {
     mutex_lock ml(*ctx->MetadataMu());
     {
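The explicit Unref requirement explains the Status change in this hunk: the old TF_RETURN_IF_ERROR around kernel->Run returned early on failure, which was safe while protected_tensors held Tensor copies with destructors, but would now skip the Unref loop and leak the pinned buffers. The rewrite captures the status, always runs the Unref loop, and only then propagates the error. A minimal sketch of the pattern (Status, Ref, RunKernel, and ExecuteWithCleanup are simplified stand-ins, not the real TensorFlow APIs):

#include <string>
#include <vector>

// Simplified stand-ins for tensorflow::Status and a pinned reference.
struct Status {
  std::string msg;
  bool ok() const { return msg.empty(); }
};

struct Ref {
  int* counter;
  void Unref() const { --*counter; }
};

// Hypothetical kernel invocation that may fail.
Status RunKernel(bool fail) {
  return fail ? Status{"kernel failed"} : Status{};
}

Status ExecuteWithCleanup(const std::vector<Ref>& protected_refs, bool fail) {
  // Capture the status instead of returning early: the Unref loop must
  // run on both the success and the error path, or the pinned buffers
  // would never be released.
  Status s = RunKernel(fail);
  for (const Ref& r : protected_refs) {
    r.Unref();
  }
  if (!s.ok()) return s;  // equivalent to TF_RETURN_IF_ERROR(s)
  // ... post-processing that assumes the kernel succeeded ...
  return Status{};
}

int main() {
  int pins = 2;
  std::vector<Ref> refs = {{&pins}, {&pins}};
  Status s = ExecuteWithCleanup(refs, /*fail=*/true);
  // Even though the kernel failed, both references were released.
  return (pins == 0 && !s.ok()) ? 0 : 1;
}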