[Executor] Avoid unnecessary NodeItem::input_type()
calls in PrepareInputs()
.
We currently unconditionally read the input type of all inputs in order to handle the increasingly rare (and deprecated!) reference-typed input case. This change caches whether there are any ref-typed inputs to a node in a bit in the fixed-length part of a `NodeItem`, and only evaluates the input type if that bit is true. In addition, this change avoids calling `InlinedVector<TensorValue>::clear()` before calling `InlinedVector<TensorValue>::resize()`. Because we overwrite all of the values in the vector, there is no need to clear it before resizing. In some cases this can avoid a free/malloc of the underlying vector storage. PiperOrigin-RevId: 310572347 Change-Id: Ie5be8eb32fd4eba522f3b661cf9f5099d5263c6f
This commit is contained in:
parent
56a99d6c28
commit
4cb8816479
tensorflow/core/common_runtime
@ -811,16 +811,14 @@ template <class PropagatorStateType>
|
||||
Status ExecutorState<PropagatorStateType>::PrepareInputs(
|
||||
const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
|
||||
AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
|
||||
inputs->clear();
|
||||
inputs->resize(item.num_inputs);
|
||||
input_alloc_attrs->clear();
|
||||
input_alloc_attrs->resize(item.num_inputs);
|
||||
|
||||
*is_input_dead = false;
|
||||
|
||||
bool is_merge = item.is_merge;
|
||||
for (int i = 0; i < item.num_inputs; ++i) {
|
||||
const bool expect_ref = IsRefType(item.input_type(i));
|
||||
const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) &&
|
||||
IsRefType(item.input_type(i));
|
||||
Entry* entry = first_input + i;
|
||||
(*input_alloc_attrs)[i] = entry->alloc_attr;
|
||||
|
||||
@ -830,7 +828,10 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
|
||||
switch (entry->state) {
|
||||
case Entry::State::NO_VALUE: {
|
||||
// Only merge and transfer nodes can have no-value inputs.
|
||||
if (!is_merge) {
|
||||
inp->mutex_if_ref = nullptr;
|
||||
if (item.is_merge) {
|
||||
inp->tensor = nullptr;
|
||||
} else {
|
||||
DCHECK(item.is_transfer_node)
|
||||
<< item.kernel->name() << " - input " << i;
|
||||
entry->state = Entry::State::HAS_CONST_TENSOR;
|
||||
@ -846,17 +847,18 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
|
||||
}
|
||||
|
||||
case Entry::State::HAS_VALUE: {
|
||||
if (expect_ref) {
|
||||
if (TF_PREDICT_FALSE(expect_ref)) {
|
||||
return AttachDef(
|
||||
errors::InvalidArgument(i, "-th input expects a ref type"),
|
||||
item.kernel->def());
|
||||
}
|
||||
inp->mutex_if_ref = nullptr;
|
||||
inp->tensor = entry->val.get();
|
||||
break;
|
||||
}
|
||||
|
||||
case Entry::State::HAS_CONST_TENSOR: {
|
||||
if (expect_ref) {
|
||||
if (TF_PREDICT_FALSE(expect_ref)) {
|
||||
return AttachDef(
|
||||
errors::InvalidArgument(i, "-th input expects a ref type"),
|
||||
item.kernel->def());
|
||||
@ -865,6 +867,7 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
|
||||
// stores a non-const `Tensor*`, and relies on the `OpKernelContext`
|
||||
// accessors making dynamic checks that prevent using an immutable
|
||||
// tensor as a mutable tensor.
|
||||
inp->mutex_if_ref = nullptr;
|
||||
inp->tensor = const_cast<Tensor*>(entry->const_tensor);
|
||||
break;
|
||||
}
|
||||
@ -872,8 +875,8 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
|
||||
case Entry::State::HAS_REF_TENSOR: {
|
||||
{
|
||||
tf_shared_lock ml(*entry->ref_tensor.mu);
|
||||
if (!entry->ref_tensor.tensor->IsInitialized() &&
|
||||
!item.is_initialization_op) {
|
||||
if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() &&
|
||||
!item.is_initialization_op)) {
|
||||
return AttachDef(errors::FailedPrecondition(
|
||||
"Attempting to use uninitialized value ",
|
||||
item.kernel->requested_input(i)),
|
||||
@ -896,12 +899,13 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs(
|
||||
}
|
||||
entry->state = Entry::State::HAS_VALUE;
|
||||
|
||||
inp->mutex_if_ref = nullptr;
|
||||
inp->tensor = entry->val.get();
|
||||
// The dtype of entry->ref_tensor.tensor could have been changed by
|
||||
// another operation that ran after the operation that "produced" it
|
||||
// executed, so re-validate that the type of the dereferenced tensor
|
||||
// matches the expected input type.
|
||||
if (item.input_type(i) != inp->tensor->dtype()) {
|
||||
if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) {
|
||||
return AttachDef(
|
||||
errors::InvalidArgument(
|
||||
i, "-th input expects type ",
|
||||
|
@ -191,9 +191,11 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
|
||||
|
||||
DCHECK_LT(DataType_MAX, 255); // Must fit in uint8
|
||||
uint8* input_types = item->input_type_base();
|
||||
item->is_any_input_ref_typed = false;
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
input_types[i] = static_cast<uint8>(n->input_type(i));
|
||||
DCHECK_EQ(item->input_type(i), n->input_type(i));
|
||||
item->is_any_input_ref_typed |= IsRefType(n->input_type(i));
|
||||
}
|
||||
|
||||
// Check ScopedAllocatorAttrs and forward_from. Also assign output_types.
|
||||
|
@ -81,6 +81,8 @@ struct NodeItem {
|
||||
// of any output edge is a
|
||||
// merge or control trigger
|
||||
// node.
|
||||
bool is_any_input_ref_typed : 1; // True iff any IsRefType(dt) for dt in this
|
||||
// node's input types.
|
||||
|
||||
// The kernel for this node.
|
||||
OpKernel* kernel = nullptr;
|
||||
|
Loading…
Reference in New Issue
Block a user