[Executor] Avoid unnecessary NodeItem::input_type() calls in PrepareInputs().

We currently unconditionally read the input type of all inputs in order to handle the increasingly rare (and deprecated!) reference-typed input case. This change caches whether there are any ref-typed inputs to a node in a bit in the fixed-length part of a `NodeItem`, and only evaluates the input type if that bit is true.

In addition, this change avoids calling `InlinedVector<TensorValue>::clear()` before calling `InlinedVector<TensorValue>::resize()`. Because we overwrite all of the values in the vector, there is no need to clear it before resizing. In some cases this can avoid a free/malloc of the underlying vector storage.

PiperOrigin-RevId: 310572347
Change-Id: Ie5be8eb32fd4eba522f3b661cf9f5099d5263c6f
This commit is contained in:
Derek Murray 2020-05-08 09:37:45 -07:00 committed by TensorFlower Gardener
parent 56a99d6c28
commit 4cb8816479
3 changed files with 18 additions and 10 deletions

View File

@ -811,16 +811,14 @@ template <class PropagatorStateType>
Status ExecutorState<PropagatorStateType>::PrepareInputs( Status ExecutorState<PropagatorStateType>::PrepareInputs(
const NodeItem& item, Entry* first_input, TensorValueVec* inputs, const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) { AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
inputs->clear();
inputs->resize(item.num_inputs); inputs->resize(item.num_inputs);
input_alloc_attrs->clear();
input_alloc_attrs->resize(item.num_inputs); input_alloc_attrs->resize(item.num_inputs);
*is_input_dead = false; *is_input_dead = false;
bool is_merge = item.is_merge;
for (int i = 0; i < item.num_inputs; ++i) { for (int i = 0; i < item.num_inputs; ++i) {
const bool expect_ref = IsRefType(item.input_type(i)); const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) &&
IsRefType(item.input_type(i));
Entry* entry = first_input + i; Entry* entry = first_input + i;
(*input_alloc_attrs)[i] = entry->alloc_attr; (*input_alloc_attrs)[i] = entry->alloc_attr;
@ -830,7 +828,10 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
switch (entry->state) { switch (entry->state) {
case Entry::State::NO_VALUE: { case Entry::State::NO_VALUE: {
// Only merge and transfer nodes can have no-value inputs. // Only merge and transfer nodes can have no-value inputs.
if (!is_merge) { inp->mutex_if_ref = nullptr;
if (item.is_merge) {
inp->tensor = nullptr;
} else {
DCHECK(item.is_transfer_node) DCHECK(item.is_transfer_node)
<< item.kernel->name() << " - input " << i; << item.kernel->name() << " - input " << i;
entry->state = Entry::State::HAS_CONST_TENSOR; entry->state = Entry::State::HAS_CONST_TENSOR;
@ -846,17 +847,18 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
} }
case Entry::State::HAS_VALUE: { case Entry::State::HAS_VALUE: {
if (expect_ref) { if (TF_PREDICT_FALSE(expect_ref)) {
return AttachDef( return AttachDef(
errors::InvalidArgument(i, "-th input expects a ref type"), errors::InvalidArgument(i, "-th input expects a ref type"),
item.kernel->def()); item.kernel->def());
} }
inp->mutex_if_ref = nullptr;
inp->tensor = entry->val.get(); inp->tensor = entry->val.get();
break; break;
} }
case Entry::State::HAS_CONST_TENSOR: { case Entry::State::HAS_CONST_TENSOR: {
if (expect_ref) { if (TF_PREDICT_FALSE(expect_ref)) {
return AttachDef( return AttachDef(
errors::InvalidArgument(i, "-th input expects a ref type"), errors::InvalidArgument(i, "-th input expects a ref type"),
item.kernel->def()); item.kernel->def());
@ -865,6 +867,7 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
// stores a non-const `Tensor*`, and relies on the `OpKernelContext` // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
// accessors making dynamic checks that prevent using an immutable // accessors making dynamic checks that prevent using an immutable
// tensor as a mutable tensor. // tensor as a mutable tensor.
inp->mutex_if_ref = nullptr;
inp->tensor = const_cast<Tensor*>(entry->const_tensor); inp->tensor = const_cast<Tensor*>(entry->const_tensor);
break; break;
} }
@ -872,8 +875,8 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
case Entry::State::HAS_REF_TENSOR: { case Entry::State::HAS_REF_TENSOR: {
{ {
tf_shared_lock ml(*entry->ref_tensor.mu); tf_shared_lock ml(*entry->ref_tensor.mu);
if (!entry->ref_tensor.tensor->IsInitialized() && if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() &&
!item.is_initialization_op) { !item.is_initialization_op)) {
return AttachDef(errors::FailedPrecondition( return AttachDef(errors::FailedPrecondition(
"Attempting to use uninitialized value ", "Attempting to use uninitialized value ",
item.kernel->requested_input(i)), item.kernel->requested_input(i)),
@ -896,12 +899,13 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
} }
entry->state = Entry::State::HAS_VALUE; entry->state = Entry::State::HAS_VALUE;
inp->mutex_if_ref = nullptr;
inp->tensor = entry->val.get(); inp->tensor = entry->val.get();
// The dtype of entry->ref_tensor.tensor could have been changed by // The dtype of entry->ref_tensor.tensor could have been changed by
// another operation that ran after the operation that "produced" it // another operation that ran after the operation that "produced" it
// executed, so re-validate that the type of the dereferenced tensor // executed, so re-validate that the type of the dereferenced tensor
// matches the expected input type. // matches the expected input type.
if (item.input_type(i) != inp->tensor->dtype()) { if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) {
return AttachDef( return AttachDef(
errors::InvalidArgument( errors::InvalidArgument(
i, "-th input expects type ", i, "-th input expects type ",

View File

@ -191,9 +191,11 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
DCHECK_LT(DataType_MAX, 255); // Must fit in uint8 DCHECK_LT(DataType_MAX, 255); // Must fit in uint8
uint8* input_types = item->input_type_base(); uint8* input_types = item->input_type_base();
item->is_any_input_ref_typed = false;
for (int i = 0; i < num_inputs; i++) { for (int i = 0; i < num_inputs; i++) {
input_types[i] = static_cast<uint8>(n->input_type(i)); input_types[i] = static_cast<uint8>(n->input_type(i));
DCHECK_EQ(item->input_type(i), n->input_type(i)); DCHECK_EQ(item->input_type(i), n->input_type(i));
item->is_any_input_ref_typed |= IsRefType(n->input_type(i));
} }
// Check ScopedAllocatorAttrs and forward_from. Also assign output_types. // Check ScopedAllocatorAttrs and forward_from. Also assign output_types.

View File

@ -81,6 +81,8 @@ struct NodeItem {
// of any output edge is a // of any output edge is a
// merge or control trigger // merge or control trigger
// node. // node.
bool is_any_input_ref_typed : 1; // True iff any IsRefType(dt) for dt in this
// node's input types.
// The kernel for this node. // The kernel for this node.
OpKernel* kernel = nullptr; OpKernel* kernel = nullptr;