Use TensorReferenceVector instead of std::vector<tensorflow::Tensor> in EagerKernelExecute in execute.cc, and reorder the input_vector-filling and protected_tensors-saving phases so that the second loop is skipped when the first loop determines it is not needed.

PiperOrigin-RevId: 268619047
Author: Yujing Zhang, 2019-09-11 23:18:37 -07:00 (committed by TensorFlower Gardener)
parent ca8929ef56
commit eb539185b9
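
The core of the change is in the second hunk below: the check for inputs that need protecting is folded into the loop that fills input_vector, and the protection loop only runs when that check found something. A minimal standalone sketch of the same two-phase pattern, using std::shared_ptr as a stand-in for TensorHandle reference counting (the names and types here are illustrative, not the TensorFlow API):

// Sketch only: two-phase reordering; std::shared_ptr stands in for
// TensorHandle ref counting. Not the TensorFlow implementation.
#include <cstdio>
#include <memory>
#include <vector>

using Handle = std::shared_ptr<int>;  // hypothetical stand-in for TensorHandle

int main() {
  Handle shared = std::make_shared<int>(7);
  std::vector<Handle> op_inputs = {std::make_shared<int>(1), shared, shared};

  // Phase 1: fill the kernel input vector and, in the same pass, note the
  // first input whose handle holds more than one reference.
  int first_index_that_needs_protecting = -1;
  std::vector<int*> input_vector(op_inputs.size());
  for (int i = 0; i < static_cast<int>(op_inputs.size()); ++i) {
    input_vector[i] = op_inputs[i].get();
    if (first_index_that_needs_protecting < 0 && op_inputs[i].use_count() != 1) {
      first_index_that_needs_protecting = i;
    }
  }

  // Phase 2: runs only when phase 1 found something to protect; copying the
  // shared_ptr plays the role of taking a TensorReference here.
  std::vector<Handle> protected_tensors;
  if (first_index_that_needs_protecting >= 0) {
    for (int i = 0; i < static_cast<int>(op_inputs.size()); ++i) {
      if (op_inputs[i].use_count() != 1) protected_tensors.push_back(op_inputs[i]);
    }
  }

  std::printf("first index needing protection: %d, protected: %zu\n",
              first_index_that_needs_protecting, protected_tensors.size());
}

In the common case where every input handle holds a single reference, only the first loop runs.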

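The last hunk below also changes the error handling around kernel->Run: the status is captured in a local instead of returning immediately, so the references taken for protected_tensors are released even when the run fails. A minimal sketch of that capture-release-propagate pattern, with Status, Ref, RunKernel, and Execute as hypothetical stand-ins rather than TensorFlow types:

// Sketch only: release references before propagating a failure status.
#include <cstdio>
#include <vector>

struct Status { bool ok; };
struct Ref {
  int* count;
  void Unref() const { --*count; }
};

Status RunKernel(bool fail) { return Status{!fail}; }

Status Execute(const std::vector<Ref>& protected_tensors, bool fail) {
  // Capture the status instead of returning early so the references taken
  // before execution are always released, even when the kernel run fails.
  Status s = RunKernel(fail);
  for (const Ref& tensor_ref : protected_tensors) {
    tensor_ref.Unref();
  }
  return s;
}

int main() {
  int refcount = 2;
  std::vector<Ref> protected_tensors = {Ref{&refcount}};
  Status s = Execute(protected_tensors, /*fail=*/true);
  std::printf("ok=%d refcount=%d\n", s.ok, refcount);  // ok=0 refcount=1
}

Returning early before the Unref loop would leak one reference per protected tensor on every failed kernel run, which is what the diff below avoids.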

@@ -40,6 +40,7 @@ limitations under the License.
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/logging.h"
 #include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor_reference.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -952,18 +953,25 @@ Status EagerKernelExecute(EagerContext* ctx,
   // overwritten during kernel execution. The reference count is incremented
   // below when we insert a copy of the Tensor into protected_tensors, and will
   // be decremented once execution is complete.
-  std::vector<tensorflow::Tensor> protected_tensors;
+  int first_index_that_needs_protecting = -1;
+  gtl::InlinedVector<TensorValue, 4> input_vector(op_inputs.size());
+  for (int i = 0; i < op_inputs.size(); ++i) {
+    TensorHandle* in = op_inputs[i];
+    TF_RETURN_IF_ERROR(in->TensorValue(&input_vector[i]));
+    if (first_index_that_needs_protecting < 0 && !in->RefCountIsOne()) {
+      first_index_that_needs_protecting = i;
+    }
+  }
+  TensorReferenceVector protected_tensors;
+  if (first_index_that_needs_protecting >= 0) {
   for (int i = 0; i < op_inputs.size(); ++i) {
     if (!op_inputs[i]->RefCountIsOne()) {
       const Tensor* input_tensor = nullptr;
       TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
-      protected_tensors.push_back(*input_tensor);
+      protected_tensors.emplace_back(TensorReference(*input_tensor));
     }
   }
-  gtl::InlinedVector<TensorValue, 4> input_vector(op_inputs.size());
-  for (int i = 0; i < op_inputs.size(); ++i) {
-    TF_RETURN_IF_ERROR(op_inputs[i]->TensorValue(&input_vector[i]));
-  }
   // TODO(apassos) figure out how to record stats for ops which are a part of
@@ -974,15 +982,18 @@ Status EagerKernelExecute(EagerContext* ctx,
   // device. We don't call it now because it is an unneeded overhead (it
   // acquires a lock) and we can't recover from errors anyway.
   ScopedStepContainer* container = ctx->StepContainer();
+  Status s;
   if (container == nullptr) {
-    TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, maybe_stats,
-                                   maybe_step_stats, graph_collector,
-                                   cancellation_manager));
+    s = kernel->Run(input_vector, &outputs, maybe_stats, maybe_step_stats,
+                    graph_collector, cancellation_manager);
   } else {
-    TF_RETURN_IF_ERROR(kernel->Run(container, input_vector, &outputs,
-                                   maybe_stats, maybe_step_stats,
-                                   graph_collector, cancellation_manager));
+    s = kernel->Run(container, input_vector, &outputs, maybe_stats,
+                    maybe_step_stats, graph_collector, cancellation_manager);
   }
+  for (const auto& tensor_ref : protected_tensors) {
+    tensor_ref.Unref();
+  }
+  TF_RETURN_IF_ERROR(s);
   if (graph_collector != nullptr) {
     mutex_lock ml(*ctx->MetadataMu());
     {