[Executor] Avoid unnecessary NodeItem::input_type() calls in PrepareInputs().

We currently unconditionally read the input type of all inputs in order to handle the increasingly rare (and deprecated!) reference-typed input case. This change caches whether there are any ref-typed inputs to a node in a bit in the fixed-length part of a `NodeItem`, and only evaluates the input type if that bit is true.

In addition, this change avoids calling `InlinedVector<TensorValue>::clear()` before calling `InlinedVector<TensorValue>::resize()`. Because we overwrite all of the values in the vector, there is no need to clear it before resizing. In some cases this can avoid a free/malloc of the underlying vector storage.

PiperOrigin-RevId: 310572347
Change-Id: Ie5be8eb32fd4eba522f3b661cf9f5099d5263c6f
This commit is contained in:
Derek Murray 2020-05-08 09:37:45 -07:00 committed by TensorFlower Gardener
parent 56a99d6c28
commit 4cb8816479
3 changed files with 18 additions and 10 deletions
tensorflow/core/common_runtime

View File

@ -811,16 +811,14 @@ template <class PropagatorStateType>
Status ExecutorState<PropagatorStateType>::PrepareInputs(
const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
inputs->clear();
inputs->resize(item.num_inputs);
input_alloc_attrs->clear();
input_alloc_attrs->resize(item.num_inputs);
*is_input_dead = false;
bool is_merge = item.is_merge;
for (int i = 0; i < item.num_inputs; ++i) {
const bool expect_ref = IsRefType(item.input_type(i));
const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) &&
IsRefType(item.input_type(i));
Entry* entry = first_input + i;
(*input_alloc_attrs)[i] = entry->alloc_attr;
@ -830,7 +828,10 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
switch (entry->state) {
case Entry::State::NO_VALUE: {
// Only merge and transfer nodes can have no-value inputs.
if (!is_merge) {
inp->mutex_if_ref = nullptr;
if (item.is_merge) {
inp->tensor = nullptr;
} else {
DCHECK(item.is_transfer_node)
<< item.kernel->name() << " - input " << i;
entry->state = Entry::State::HAS_CONST_TENSOR;
@ -846,17 +847,18 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
}
case Entry::State::HAS_VALUE: {
if (expect_ref) {
if (TF_PREDICT_FALSE(expect_ref)) {
return AttachDef(
errors::InvalidArgument(i, "-th input expects a ref type"),
item.kernel->def());
}
inp->mutex_if_ref = nullptr;
inp->tensor = entry->val.get();
break;
}
case Entry::State::HAS_CONST_TENSOR: {
if (expect_ref) {
if (TF_PREDICT_FALSE(expect_ref)) {
return AttachDef(
errors::InvalidArgument(i, "-th input expects a ref type"),
item.kernel->def());
@ -865,6 +867,7 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
// stores a non-const `Tensor*`, and relies on the `OpKernelContext`
// accessors making dynamic checks that prevent using an immutable
// tensor as a mutable tensor.
inp->mutex_if_ref = nullptr;
inp->tensor = const_cast<Tensor*>(entry->const_tensor);
break;
}
@ -872,8 +875,8 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
case Entry::State::HAS_REF_TENSOR: {
{
tf_shared_lock ml(*entry->ref_tensor.mu);
if (!entry->ref_tensor.tensor->IsInitialized() &&
!item.is_initialization_op) {
if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() &&
!item.is_initialization_op)) {
return AttachDef(errors::FailedPrecondition(
"Attempting to use uninitialized value ",
item.kernel->requested_input(i)),
@ -896,12 +899,13 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
}
entry->state = Entry::State::HAS_VALUE;
inp->mutex_if_ref = nullptr;
inp->tensor = entry->val.get();
// The dtype of entry->ref_tensor.tensor could have been changed by
// another operation that ran after the operation that "produced" it
// executed, so re-validate that the type of the dereferenced tensor
// matches the expected input type.
if (item.input_type(i) != inp->tensor->dtype()) {
if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) {
return AttachDef(
errors::InvalidArgument(
i, "-th input expects type ",

View File

@ -191,9 +191,11 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
DCHECK_LT(DataType_MAX, 255); // Must fit in uint8
uint8* input_types = item->input_type_base();
item->is_any_input_ref_typed = false;
for (int i = 0; i < num_inputs; i++) {
input_types[i] = static_cast<uint8>(n->input_type(i));
DCHECK_EQ(item->input_type(i), n->input_type(i));
item->is_any_input_ref_typed |= IsRefType(n->input_type(i));
}
// Check ScopedAllocatorAttrs and forward_from. Also assign output_types.

View File

@ -81,6 +81,8 @@ struct NodeItem {
// of any output edge is a
// merge or control trigger
// node.
bool is_any_input_ref_typed : 1; // True iff any IsRefType(dt) for dt in this
// node's input types.
// The kernel for this node.
OpKernel* kernel = nullptr;