[Executor] Avoid unnecessary NodeItem::input_type() calls in PrepareInputs().
We currently unconditionally read the input type of all inputs in order to
handle the increasingly rare (and deprecated!) reference-typed input case.
This change caches whether there are any ref-typed inputs to a node in a bit
in the fixed-length part of a `NodeItem`, and only evaluates the input type
if that bit is set.

In addition, this change avoids calling `InlinedVector<TensorValue>::clear()`
before calling `InlinedVector<TensorValue>::resize()`. Because we overwrite
all of the values in the vector, there is no need to clear it before
resizing. In some cases this avoids a free/malloc of the underlying vector
storage.

PiperOrigin-RevId: 310572347
Change-Id: Ie5be8eb32fd4eba522f3b661cf9f5099d5263c6f
commit 4cb8816479
parent 56a99d6c28
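
Before the diff itself, here is a minimal, self-contained sketch of the
caching pattern described above. The names (NodeItemSketch, InitTypeBits,
ExpectsRef) and the toy DataType enum are illustrative stand-ins, not
TensorFlow's actual declarations:

#include <cstdint>
#include <vector>

enum class DataType : uint8_t { kFloat, kInt32, kFloatRef, kInt32Ref };

inline bool IsRefType(DataType dt) {
  return dt == DataType::kFloatRef || dt == DataType::kInt32Ref;
}

struct NodeItemSketch {
  std::vector<DataType> input_types;
  bool is_any_input_ref_typed : 1;  // Cached summary of input_types.

  // Computed once when the item is initialized, mirroring InitializeNode().
  void InitTypeBits() {
    is_any_input_ref_typed = false;
    for (DataType dt : input_types) {
      is_any_input_ref_typed |= IsRefType(dt);
    }
  }
};

// Per-input check in the hot path: && short-circuits, so input_types[i]
// is only read when the node has at least one ref-typed input, which is
// the rare, deprecated case.
inline bool ExpectsRef(const NodeItemSketch& item, int i) {
  return item.is_any_input_ref_typed && IsRefType(item.input_types[i]);
}
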
@@ -811,16 +811,14 @@ template <class PropagatorStateType>
 Status ExecutorState<PropagatorStateType>::PrepareInputs(
     const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
     AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
-  inputs->clear();
   inputs->resize(item.num_inputs);
-  input_alloc_attrs->clear();
   input_alloc_attrs->resize(item.num_inputs);

   *is_input_dead = false;

-  bool is_merge = item.is_merge;
   for (int i = 0; i < item.num_inputs; ++i) {
-    const bool expect_ref = IsRefType(item.input_type(i));
+    const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) &&
+                            IsRefType(item.input_type(i));
     Entry* entry = first_input + i;
     (*input_alloc_attrs)[i] = entry->alloc_attr;

@@ -830,7 +828,10 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
     switch (entry->state) {
       case Entry::State::NO_VALUE: {
         // Only merge and transfer nodes can have no-value inputs.
-        if (!is_merge) {
+        inp->mutex_if_ref = nullptr;
+        if (item.is_merge) {
+          inp->tensor = nullptr;
+        } else {
           DCHECK(item.is_transfer_node)
               << item.kernel->name() << " - input " << i;
           entry->state = Entry::State::HAS_CONST_TENSOR;
@@ -846,17 +847,18 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
       }

       case Entry::State::HAS_VALUE: {
-        if (expect_ref) {
+        if (TF_PREDICT_FALSE(expect_ref)) {
           return AttachDef(
               errors::InvalidArgument(i, "-th input expects a ref type"),
               item.kernel->def());
         }
+        inp->mutex_if_ref = nullptr;
         inp->tensor = entry->val.get();
         break;
       }

       case Entry::State::HAS_CONST_TENSOR: {
-        if (expect_ref) {
+        if (TF_PREDICT_FALSE(expect_ref)) {
           return AttachDef(
               errors::InvalidArgument(i, "-th input expects a ref type"),
               item.kernel->def());
@@ -865,6 +867,7 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
         // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
         // accessors making dynamic checks that prevent using an immutable
         // tensor as a mutable tensor.
+        inp->mutex_if_ref = nullptr;
         inp->tensor = const_cast<Tensor*>(entry->const_tensor);
         break;
       }
@@ -872,8 +875,8 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
       case Entry::State::HAS_REF_TENSOR: {
         {
           tf_shared_lock ml(*entry->ref_tensor.mu);
-          if (!entry->ref_tensor.tensor->IsInitialized() &&
-              !item.is_initialization_op) {
+          if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() &&
+                               !item.is_initialization_op)) {
             return AttachDef(errors::FailedPrecondition(
                                  "Attempting to use uninitialized value ",
                                  item.kernel->requested_input(i)),
@@ -896,12 +899,13 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
         }
         entry->state = Entry::State::HAS_VALUE;

+        inp->mutex_if_ref = nullptr;
         inp->tensor = entry->val.get();
         // The dtype of entry->ref_tensor.tensor could have been changed by
         // another operation that ran after the operation that "produced" it
         // executed, so re-validate that the type of the dereferenced tensor
         // matches the expected input type.
-        if (item.input_type(i) != inp->tensor->dtype()) {
+        if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) {
           return AttachDef(
               errors::InvalidArgument(
                   i, "-th input expects type ",
@@ -191,9 +191,11 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {

   DCHECK_LT(DataType_MAX, 255);  // Must fit in uint8
   uint8* input_types = item->input_type_base();
+  item->is_any_input_ref_typed = false;
   for (int i = 0; i < num_inputs; i++) {
     input_types[i] = static_cast<uint8>(n->input_type(i));
     DCHECK_EQ(item->input_type(i), n->input_type(i));
+    item->is_any_input_ref_typed |= IsRefType(n->input_type(i));
   }

   // Check ScopedAllocatorAttrs and forward_from. Also assign output_types.
@@ -81,6 +81,8 @@ struct NodeItem {
                                                      // of any output edge is a
                                                      // merge or control trigger
                                                      // node.
+  bool is_any_input_ref_typed : 1;  // True iff any IsRefType(dt) for dt in
+                                    // this node's input types.

   // The kernel for this node.
   OpKernel* kernel = nullptr;
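
A closing note on the clear()/resize() half of the change: in recent
TensorFlow, gtl::InlinedVector is an alias for absl::InlinedVector, whose
clear() destroys the elements and releases any heap allocation, so a later
resize() past the inline capacity must reallocate. Skipping the clear()
lets resize() reuse the existing storage. A small illustrative sketch
(FillAll is a hypothetical helper, not from this commit):

#include "absl/container/inlined_vector.h"

// Fills `out` with n values, overwriting every slot, so no clear() is
// needed beforehand.
void FillAll(absl::InlinedVector<int, 4>* out, int n) {
  // Calling out->clear() here would destroy the elements AND free any heap
  // allocation, forcing the resize() below to allocate again when n > 4.
  out->resize(n);  // Reuses the existing heap storage when it is big enough.
  for (int i = 0; i < n; ++i) {
    (*out)[i] = i;  // Every element is overwritten, as in PrepareInputs().
  }
}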