Use TensorReferenceVector in EagerKernelExecute in execute.cc instead of std::vector<tensorflow::Tensor>, and reorder the input_vector-filling and protected_tensors-saving phases so the second loop is skipped when the first loop determines it is not needed.
PiperOrigin-RevId: 268619047
commit eb539185b9
parent ca8929ef56
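At a high level, the new flow folds the "does any input need protecting?" check into the loop that already fills input_vector, and only walks the inputs a second time (to take references) when at least one input is shared. The following is a minimal, self-contained C++ sketch of that control flow; ToyBuffer, ToyHandle, and the manual ref-counting are invented stand-ins for illustration, not TensorFlow types.

// Simplified sketch of the reordered two-phase flow (hypothetical types,
// not the real TensorFlow classes).
#include <cstdio>
#include <vector>

struct ToyBuffer {
  int refcount = 1;            // stand-in for a ref-counted tensor buffer
  void Ref() { ++refcount; }
  void Unref() { --refcount; }
};

struct ToyHandle {
  ToyBuffer* buf;
  bool RefCountIsOne() const { return buf->refcount == 1; }
};

void Execute(const std::vector<ToyHandle>& inputs) {
  // Phase 1: fill the input values and remember whether any input is
  // shared (refcount > 1) and therefore needs protecting.
  std::vector<ToyBuffer*> input_values(inputs.size());
  int first_index_that_needs_protecting = -1;
  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    input_values[i] = inputs[i].buf;
    if (first_index_that_needs_protecting < 0 && !inputs[i].RefCountIsOne()) {
      first_index_that_needs_protecting = i;
    }
  }

  // Phase 2: only runs when phase 1 found a shared input; takes an extra
  // reference on each shared buffer so it outlives kernel execution.
  std::vector<ToyBuffer*> protected_buffers;
  if (first_index_that_needs_protecting >= 0) {
    for (const ToyHandle& in : inputs) {
      if (!in.RefCountIsOne()) {
        in.buf->Ref();
        protected_buffers.push_back(in.buf);
      }
    }
  }

  std::printf("running kernel on %zu inputs, %zu protected\n",
              input_values.size(), protected_buffers.size());

  // Release the extra references once execution is complete.
  for (ToyBuffer* buf : protected_buffers) buf->Unref();
}

int main() {
  ToyBuffer a, b;
  b.refcount = 2;  // b is shared elsewhere, so it will be protected
  Execute({{&a}, {&b}});
}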
@@ -40,6 +40,7 @@ limitations under the License.
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/logging.h"
 #include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor_reference.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -952,18 +953,25 @@ Status EagerKernelExecute(EagerContext* ctx,
   // overwritten during kernel execution. The reference count is incremented
   // below when we insert a copy of the Tensor into protected_tensors, and will
   // be decremented once execution is complete.
-  std::vector<tensorflow::Tensor> protected_tensors;
+  int first_index_that_needs_protecting = -1;
+  gtl::InlinedVector<TensorValue, 4> input_vector(op_inputs.size());
   for (int i = 0; i < op_inputs.size(); ++i) {
-    if (!op_inputs[i]->RefCountIsOne()) {
-      const Tensor* input_tensor = nullptr;
-      TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
-      protected_tensors.push_back(*input_tensor);
+    TensorHandle* in = op_inputs[i];
+    TF_RETURN_IF_ERROR(in->TensorValue(&input_vector[i]));
+    if (first_index_that_needs_protecting < 0 && !in->RefCountIsOne()) {
+      first_index_that_needs_protecting = i;
     }
   }

-  gtl::InlinedVector<TensorValue, 4> input_vector(op_inputs.size());
-  for (int i = 0; i < op_inputs.size(); ++i) {
-    TF_RETURN_IF_ERROR(op_inputs[i]->TensorValue(&input_vector[i]));
+  TensorReferenceVector protected_tensors;
+  if (first_index_that_needs_protecting >= 0) {
+    for (int i = 0; i < op_inputs.size(); ++i) {
+      if (!op_inputs[i]->RefCountIsOne()) {
+        const Tensor* input_tensor = nullptr;
+        TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
+        protected_tensors.emplace_back(TensorReference(*input_tensor));
+      }
+    }
   }

   // TODO(apassos) figure out how to record stats for ops which are a part of
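Unlike the std::vector<tensorflow::Tensor> it replaces, a TensorReference does not release the underlying buffer automatically: constructing one increments the buffer's reference count, and the holder must call Unref() explicitly (which the next hunk does after the kernel runs). A small usage sketch follows, assuming only the TensorReference(const Tensor&) constructor and Unref() from tensor_reference.h; ProtectDuringWork is a hypothetical function name.

// Hedged sketch of the TensorReference pattern used above (not the actual
// execute.cc code): take a reference while a buffer must stay alive, then
// release it explicitly.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"

namespace tensorflow {

void ProtectDuringWork(const Tensor& t) {
  // Constructing a TensorReference increments the refcount of t's buffer,
  // so the buffer survives even if t itself is overwritten or destroyed.
  TensorReference ref(t);

  // ... run work that may overwrite the original handle's tensor ...

  // There is no RAII release; Unref() must be called explicitly, mirroring
  // the loop over protected_tensors after kernel->Run().
  ref.Unref();
}

}  // namespace tensorflow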
@@ -974,15 +982,18 @@ Status EagerKernelExecute(EagerContext* ctx,
   // device. We don't call it now because it is an unneeded overhead (it
   // acquires a lock) and we can't recover from errors anyway.
   ScopedStepContainer* container = ctx->StepContainer();
+  Status s;
   if (container == nullptr) {
-    TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, maybe_stats,
-                                   maybe_step_stats, graph_collector,
-                                   cancellation_manager));
+    s = kernel->Run(input_vector, &outputs, maybe_stats, maybe_step_stats,
+                    graph_collector, cancellation_manager);
   } else {
-    TF_RETURN_IF_ERROR(kernel->Run(container, input_vector, &outputs,
-                                   maybe_stats, maybe_step_stats,
-                                   graph_collector, cancellation_manager));
+    s = kernel->Run(container, input_vector, &outputs, maybe_stats,
+                    maybe_step_stats, graph_collector, cancellation_manager);
   }
+  for (const auto& tensor_ref : protected_tensors) {
+    tensor_ref.Unref();
+  }
+  TF_RETURN_IF_ERROR(s);
   if (graph_collector != nullptr) {
     mutex_lock ml(*ctx->MetadataMu());
     {
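The kernel->Run calls now capture the Status in s instead of using TF_RETURN_IF_ERROR, so the Unref loop over protected_tensors also runs on the error path; an early return would otherwise skip it and leak the extra buffer references taken above. Below is a toy, self-contained sketch of that cleanup-before-return shape (Status, Ref, and RunKernel here are placeholders, not the TensorFlow types).

// Self-contained sketch of the capture-status, clean-up, then-propagate
// pattern in the last hunk (toy types; not the TensorFlow implementation).
#include <cstdio>
#include <vector>

struct Status {
  bool ok;
};

struct Ref {
  void Unref() const { std::puts("released"); }
};

Status RunKernel(bool fail) { return Status{!fail}; }

Status Execute(bool fail) {
  std::vector<Ref> protected_refs(2);
  // Capture the status instead of returning early, so the Unref loop below
  // runs on both the success and the error path.
  Status s = RunKernel(fail);
  for (const Ref& r : protected_refs) r.Unref();
  return s;  // propagate any error only after the references are released
}

int main() {
  Execute(false);
  Execute(true);
}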