diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 1f2a364258f..74de6b28d3f 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -811,16 +811,14 @@ template <class PropagatorStateType> Status ExecutorState<PropagatorStateType>::PrepareInputs( const NodeItem& item, Entry* first_input, TensorValueVec* inputs, AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) { - inputs->clear(); inputs->resize(item.num_inputs); - input_alloc_attrs->clear(); input_alloc_attrs->resize(item.num_inputs); *is_input_dead = false; - bool is_merge = item.is_merge; for (int i = 0; i < item.num_inputs; ++i) { - const bool expect_ref = IsRefType(item.input_type(i)); + const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) && + IsRefType(item.input_type(i)); Entry* entry = first_input + i; (*input_alloc_attrs)[i] = entry->alloc_attr; @@ -830,7 +828,10 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs( switch (entry->state) { case Entry::State::NO_VALUE: { // Only merge and transfer nodes can have no-value inputs. - if (!is_merge) { + inp->mutex_if_ref = nullptr; + if (item.is_merge) { + inp->tensor = nullptr; + } else { DCHECK(item.is_transfer_node) << item.kernel->name() << " - input " << i; entry->state = Entry::State::HAS_CONST_TENSOR; @@ -846,17 +847,18 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs( } case Entry::State::HAS_VALUE: { - if (expect_ref) { + if (TF_PREDICT_FALSE(expect_ref)) { return AttachDef( errors::InvalidArgument(i, "-th input expects a ref type"), item.kernel->def()); } + inp->mutex_if_ref = nullptr; inp->tensor = entry->val.get(); break; } case Entry::State::HAS_CONST_TENSOR: { - if (expect_ref) { + if (TF_PREDICT_FALSE(expect_ref)) { return AttachDef( errors::InvalidArgument(i, "-th input expects a ref type"), item.kernel->def()); @@ -865,6 +867,7 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs( // stores a non-const `Tensor*`, and relies on the `OpKernelContext` // accessors making dynamic checks that prevent using an immutable // tensor as a mutable tensor. + inp->mutex_if_ref = nullptr; inp->tensor = const_cast<Tensor*>(entry->const_tensor); break; } @@ -872,8 +875,8 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs( case Entry::State::HAS_REF_TENSOR: { { tf_shared_lock ml(*entry->ref_tensor.mu); - if (!entry->ref_tensor.tensor->IsInitialized() && - !item.is_initialization_op) { + if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() && + !item.is_initialization_op)) { return AttachDef(errors::FailedPrecondition( "Attempting to use uninitialized value ", item.kernel->requested_input(i)), @@ -896,12 +899,13 @@ Status ExecutorState<PropagatorStateType>::PrepareInputs( } entry->state = Entry::State::HAS_VALUE; + inp->mutex_if_ref = nullptr; inp->tensor = entry->val.get(); // The dtype of entry->ref_tensor.tensor could have been changed by // another operation that ran after the operation that "produced" it // executed, so re-validate that the type of the dereferenced tensor // matches the expected input type. - if (item.input_type(i) != inp->tensor->dtype()) { + if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) { return AttachDef( errors::InvalidArgument( i, "-th input expects type ", diff --git a/tensorflow/core/common_runtime/graph_view.cc b/tensorflow/core/common_runtime/graph_view.cc index 7db0781551d..7a63e06814a 100644 --- a/tensorflow/core/common_runtime/graph_view.cc +++ b/tensorflow/core/common_runtime/graph_view.cc @@ -191,9 +191,11 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) { DCHECK_LT(DataType_MAX, 255); // Must fit in uint8 uint8* input_types = item->input_type_base(); + item->is_any_input_ref_typed = false; for (int i = 0; i < num_inputs; i++) { input_types[i] = static_cast<uint8>(n->input_type(i)); DCHECK_EQ(item->input_type(i), n->input_type(i)); + item->is_any_input_ref_typed |= IsRefType(n->input_type(i)); } // Check ScopedAllocatorAttrs and forward_from. Also assign output_types. diff --git a/tensorflow/core/common_runtime/graph_view.h b/tensorflow/core/common_runtime/graph_view.h index 6d31555ed9a..38eb3e33bcb 100644 --- a/tensorflow/core/common_runtime/graph_view.h +++ b/tensorflow/core/common_runtime/graph_view.h @@ -81,6 +81,8 @@ struct NodeItem { // of any output edge is a // merge or control trigger // node. + bool is_any_input_ref_typed : 1; // True iff any IsRefType(dt) for dt in this + // node's input types. // The kernel for this node. OpKernel* kernel = nullptr;